From f1de9432d4280fbaf883939b4ebd6408ee015c81 Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Thu, 28 Dec 2023 21:57:16 -0500
Subject: [PATCH 01/29] dummy execution functions

---
 stable-diffusion.cpp | 697 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 697 insertions(+)

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index c1ffdc804..44dbe85ac 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -4948,6 +4948,701 @@ struct ESRGAN {
     }
 };
 
+/*
+    =================================== ControlNet ===================================
+    Reference: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/cldm/cldm.py
+
+*/
+
+struct CNHintBlock {
+    int hint_channels    = 3;
+    int model_channels   = 320;  // SD 1.5
+    int feat_channels[4] = {16, 32, 96, 256};
+    int num_blocks       = 3;
+    ggml_tensor* conv_first_w;  // [feat_channels[0], hint_channels, 3, 3]
+    ggml_tensor* conv_first_b;  // [feat_channels[0]]
+
+    struct hint_block {
+        ggml_tensor* conv_0_w;  // [feat_channels[idx], feat_channels[idx], 3, 3]
+        ggml_tensor* conv_0_b;  // [feat_channels[idx]]
+
+        ggml_tensor* conv_1_w;  // [feat_channels[idx + 1], feat_channels[idx], 3, 3]
+        ggml_tensor* conv_1_b;  // [feat_channels[idx + 1]]
+    };
+
+    hint_block blocks[3];
+    ggml_tensor* conv_final_w;  // [feat_channels[3], model_channels, 3, 3]
+    ggml_tensor* conv_final_b;  // [feat_channels[3]]
+
+    size_t calculate_mem_size() {
+        size_t mem_size = feat_channels[0] * hint_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_first_w
+        mem_size += feat_channels[0] * ggml_type_size(GGML_TYPE_F32);                                // conv_first_b
+        for (int i = 0; i < num_blocks; i++) {
+            mem_size += feat_channels[i] * feat_channels[i] * 3 * 3 * ggml_type_size(GGML_TYPE_F16);      // conv_0_w
+            mem_size += feat_channels[i] * ggml_type_size(GGML_TYPE_F32);                                 // conv_0_b
+            mem_size += feat_channels[i + 1] * feat_channels[i] * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_1_w
+            mem_size += feat_channels[i + 1] * ggml_type_size(GGML_TYPE_F32);                             // conv_1_b
+        }
+        mem_size += model_channels * feat_channels[3] * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_final_w
+        mem_size += model_channels * ggml_type_size(GGML_TYPE_F32);                             // conv_final_b
+        return static_cast<size_t>(mem_size);
+    }
+
+    void init_params(struct ggml_context* ctx) {
+        conv_first_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, hint_channels, feat_channels[0]);
+        conv_first_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, feat_channels[0]);
+
+        for (int i = 0; i < num_blocks; i++) {
+            blocks[i].conv_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, feat_channels[i], feat_channels[i]);
+            blocks[i].conv_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, feat_channels[i]);
+            blocks[i].conv_1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, feat_channels[i], feat_channels[i + 1]);
+            blocks[i].conv_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, feat_channels[i + 1]);
+        }
+
+        conv_final_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, feat_channels[3], model_channels);
+        conv_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model_channels);
+    }
+
+    void map_by_name(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
+        tensors[prefix + "input_hint_block.0.weight"] = conv_first_w;
+        tensors[prefix + "input_hint_block.0.bias"]   = conv_first_b;
+        int index = 2;
+        for (int i = 0; i < num_blocks; i++) {
+            tensors[prefix + "input_hint_block." + std::to_string(index) + ".weight"] = blocks[i].conv_0_w;
+            tensors[prefix + "input_hint_block." + std::to_string(index) + ".bias"]   = blocks[i].conv_0_b;
+            index += 2;
+            tensors[prefix + "input_hint_block." + std::to_string(index) + ".weight"] = blocks[i].conv_1_w;
+            tensors[prefix + "input_hint_block." + std::to_string(index) + ".bias"]   = blocks[i].conv_1_b;
+            index += 2;
+        }
+        tensors[prefix + "input_hint_block.14.weight"] = conv_final_w;
+        tensors[prefix + "input_hint_block.14.bias"]   = conv_final_b;
+    }
+
+    ggml_tensor* forward(ggml_context* ctx, ggml_tensor* x) {
+        ggml_tensor* h;
+        h = ggml_nn_conv_2d(ctx, x, conv_first_w, conv_first_b, 1, 1, 1, 1);
+        h = ggml_silu_inplace(ctx, h);
+
+        for (int i = 0; i < num_blocks; i++) {
+            // operations.conv_nd(dims, 16, 16, 3, padding=1)
+            h = ggml_nn_conv_2d(ctx, h, blocks[i].conv_0_w, blocks[i].conv_0_b, 1, 1, 1, 1);
+            h = ggml_silu_inplace(ctx, h);
+
+            // operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2)
+            h = ggml_nn_conv_2d(ctx, h, blocks[i].conv_1_w, blocks[i].conv_1_b, 2, 2, 1, 1);
+            h = ggml_silu_inplace(ctx, h);
+        }
+
+        h = ggml_nn_conv_2d(ctx, h, conv_final_w, conv_final_b, 1, 1, 1, 1);
+        h = ggml_silu_inplace(ctx, h);
+        return h;
+    }
+};
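A standalone sketch (not part of the patch) of what this block does to the hint's spatial shape: each of the three hint_block stages ends in a stride-2 convolution, so a 512x512 RGB hint comes out at 64x64 with model_channels (320) channels, which is exactly the resolution of the UNet's first input block, and is why the two feature maps can later be summed.

#include <cstdio>

int main() {
    int feat_channels[4] = {16, 32, 96, 256};
    int w = 512, h = 512;
    int c = feat_channels[0];  // after conv_first (stride 1)
    for (int i = 0; i < 3; i++) {
        // conv_0 keeps the shape; conv_1 halves it and widens the channels
        w /= 2;
        h /= 2;
        c = feat_channels[i + 1];
        printf("after block %d: %dx%dx%d\n", i, w, h, c);
    }
    printf("guided_hint:   %dx%dx320\n", w, h);  // conv_final -> model_channels
    return 0;
}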
+
+struct CNZeroConv {
+    int channels;
+    ggml_tensor* conv_w;  // [channels, channels, 1, 1]
+    ggml_tensor* conv_b;  // [channels]
+
+    void init_params(struct ggml_context* ctx) {
+        conv_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, channels, channels);
+        conv_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels);
+    }
+};
+
+struct ControlNet {
+    int in_channels    = 4;
+    int model_channels = 320;
+    int out_channels   = 4;
+    int num_res_blocks = 2;
+    std::vector<int> attention_resolutions = {4, 2, 1};
+    std::vector<int> channel_mult          = {1, 2, 4, 4};
+    std::vector<int> transformer_depth     = {1, 1, 1, 1};
+    int time_embed_dim    = 1280;  // model_channels*4
+    int num_heads         = 8;
+    int num_head_channels = -1;  // channels // num_heads
+    int context_dim       = 768;
+    int middle_out_channel;
+    CNHintBlock input_hint_block;
+    CNZeroConv zero_convs[12];
+    int num_zero_convs = 1;
+
+    // network params
+    struct ggml_tensor* time_embed_0_w;  // [time_embed_dim, model_channels]
+    struct ggml_tensor* time_embed_0_b;  // [time_embed_dim, ]
+    // time_embed_1 is nn.SILU()
+    struct ggml_tensor* time_embed_2_w;  // [time_embed_dim, time_embed_dim]
+    struct ggml_tensor* time_embed_2_b;  // [time_embed_dim, ]
+
+    struct ggml_tensor* input_block_0_w;  // [model_channels, in_channels, 3, 3]
+    struct ggml_tensor* input_block_0_b;  // [model_channels, ]
+
+    // input_blocks
+    ResBlock input_res_blocks[4][2];
+    SpatialTransformer input_transformers[3][2];
+    DownSample input_down_samples[3];
+
+    // middle_block
+    ResBlock middle_block_0;
+    SpatialTransformer middle_block_1;
+    ResBlock middle_block_2;
+
+    struct ggml_tensor* middle_block_out_w;  // [middle_out_channel, middle_out_channel, 1, 1]
+    struct ggml_tensor* middle_block_out_b;  // [middle_out_channel, ]
+
+    struct ggml_context* ctx;
+    ggml_backend_buffer_t params_buffer;
+    ggml_backend_buffer_t compute_buffer;  // for compute
+    struct ggml_allocr* compute_alloc = NULL;
+    size_t compute_memory_buffer_size = -1;
+
+    size_t memory_buffer_size = 0;
+    ggml_type wtype;
+    ggml_backend_t backend = NULL;
+
+    ControlNet() {
+        // input_blocks
+        std::vector<int> input_block_chans;
+        input_block_chans.push_back(model_channels);
+        int ch                 = model_channels;
+        zero_convs[0].channels = model_channels;
+        int ds                 = 1;
+
+        int len_mults = channel_mult.size();
+        for (int i = 0; i < len_mults; i++) {
+            int mult = channel_mult[i];
+            for (int j = 0; j < num_res_blocks; j++) {
+                input_res_blocks[i][j].channels     = ch;
+                input_res_blocks[i][j].emb_channels = time_embed_dim;
+                input_res_blocks[i][j].out_channels = mult * model_channels;
+
+                ch = mult * model_channels;
+                if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
+                    int n_head = num_heads;
+                    int d_head = ch / num_heads;
+                    if (num_head_channels != -1) {
+                        d_head = num_head_channels;
+                        n_head = ch / d_head;
+                    }
+                    input_transformers[i][j]             = SpatialTransformer(transformer_depth[i]);
+                    input_transformers[i][j].in_channels = ch;
+                    input_transformers[i][j].n_head      = n_head;
+                    input_transformers[i][j].d_head      = d_head;
+                    input_transformers[i][j].context_dim = context_dim;
+                }
+                input_block_chans.push_back(ch);
+
+                zero_convs[num_zero_convs].channels = ch;
+                num_zero_convs++;
+            }
+            if (i != len_mults - 1) {
+                input_down_samples[i].channels     = ch;
+                input_down_samples[i].out_channels = ch;
+                input_block_chans.push_back(ch);
+
+                zero_convs[num_zero_convs].channels = ch;
+                num_zero_convs++;
+                ds *= 2;
+            }
+        }
+        GGML_ASSERT(num_zero_convs == 12);
+
+        // middle blocks
+        middle_block_0.channels     = ch;
+        middle_block_0.emb_channels = time_embed_dim;
+        middle_block_0.out_channels = ch;
+
+        int n_head = num_heads;
+        int d_head = ch / num_heads;
+        if (num_head_channels != -1) {
+            d_head = num_head_channels;
+            n_head = ch / d_head;
+        }
+        middle_block_1             = SpatialTransformer(transformer_depth[transformer_depth.size() - 1]);
+        middle_block_1.in_channels = ch;
+        middle_block_1.n_head      = n_head;
+        middle_block_1.d_head      = d_head;
+        middle_block_1.context_dim = context_dim;
+
+        middle_block_2.channels     = ch;
+        middle_block_2.emb_channels = time_embed_dim;
+        middle_block_2.out_channels = ch;
+        middle_out_channel          = ch;
+    }
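+    // For SD 1.5 (channel_mult = {1, 2, 4, 4}, num_res_blocks = 2) the loop
+    // above fills zero_convs with the channel counts
+    //     {320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280}
+    // (one zero conv per UNet input-block output) and leaves
+    // middle_out_channel = 1280, which is what the
+    // GGML_ASSERT(num_zero_convs == 12) checks.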
+
+    size_t calculate_mem_size() {
+        double mem_size = 0;
+        mem_size += input_hint_block.calculate_mem_size();
+        mem_size += time_embed_dim * model_channels * ggml_type_sizef(wtype);  // time_embed_0_w
+        mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32);          // time_embed_0_b
+        mem_size += time_embed_dim * time_embed_dim * ggml_type_sizef(wtype);  // time_embed_2_w
+        mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32);           // time_embed_2_b
+
+        mem_size += model_channels * in_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);  // input_block_0_w
+        mem_size += model_channels * ggml_type_sizef(GGML_TYPE_F32);                        // input_block_0_b
+
+        // input_blocks
+        int ds        = 1;
+        int len_mults = channel_mult.size();
+        for (int i = 0; i < len_mults; i++) {
+            for (int j = 0; j < num_res_blocks; j++) {
+                mem_size += input_res_blocks[i][j].calculate_mem_size(wtype);
+                if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
+                    mem_size += input_transformers[i][j].calculate_mem_size(wtype);
+                }
+            }
+            if (i != len_mults - 1) {
+                ds *= 2;
+                mem_size += input_down_samples[i].calculate_mem_size(wtype);
+            }
+        }
+
+        for (int i = 0; i < num_zero_convs; i++) {
+            mem_size += zero_convs[i].channels * zero_convs[i].channels * ggml_type_sizef(GGML_TYPE_F16);
+            mem_size += zero_convs[i].channels * ggml_type_sizef(GGML_TYPE_F32);
+        }
+
+        // middle_block
+        mem_size += middle_block_0.calculate_mem_size(wtype);
+        mem_size += middle_block_1.calculate_mem_size(wtype);
+        mem_size += middle_block_2.calculate_mem_size(wtype);
+
+        mem_size += middle_out_channel * middle_out_channel * ggml_type_sizef(GGML_TYPE_F16);  // middle_block_out_w
+        mem_size += middle_out_channel * ggml_type_sizef(GGML_TYPE_F32);                       // middle_block_out_b
+
+        return static_cast<size_t>(mem_size);
+    }
+
+    int get_num_tensors() {
+        // in
+        int num_tensors = 6;
+
+        num_tensors += num_zero_convs * 2;
+
+        // input blocks
+        int ds        = 1;
+        int len_mults = channel_mult.size();
+        for (int i = 0; i < len_mults; i++) {
+            for (int j = 0; j < num_res_blocks; j++) {
+                num_tensors += 12;
+                if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
+                    num_tensors += input_transformers[i][j].get_num_tensors();
+                }
+            }
+            if (i != len_mults - 1) {
+                ds *= 2;
+                num_tensors += 2;
+            }
+        }
+
+        // middle blocks
+        num_tensors += 13 * 2;
+        num_tensors += middle_block_1.get_num_tensors();
+        return num_tensors;
+    }
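+    // Sizing example: the largest zero conv (1280 -> 1280, 1x1) accounts for
+    //     1280 * 1280 * 2 bytes (F16 weight) + 1280 * 4 bytes (F32 bias)
+    //   = 3,276,800 + 5,120 bytes, roughly 3.13 MB of the total that
+    // calculate_mem_size() sums up above.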
+
+    bool init(ggml_backend_t backend_, ggml_type wtype_) {
+        backend            = backend_;
+        wtype              = wtype_;
+        memory_buffer_size = 4 * 1024 * 1024;  // 4 MB, for padding
+        memory_buffer_size += calculate_mem_size();
+        int num_tensors = get_num_tensors();
+
+        LOG_DEBUG("control net params backend buffer size = % 6.2f MB (%i tensors)", memory_buffer_size / (1024.0 * 1024.0), num_tensors);
+
+        struct ggml_init_params params;
+        params.mem_size   = static_cast<size_t>(num_tensors * ggml_tensor_overhead()) + 1 * 1024 * 1024;
+        params.mem_buffer = NULL;
+        params.no_alloc   = true;
+        // LOG_DEBUG("mem_size %u ", params.mem_size);
+
+        ctx = ggml_init(params);
+        if (!ctx) {
+            LOG_ERROR("ggml_init() failed");
+            return false;
+        }
+
+        params_buffer = ggml_backend_alloc_buffer(backend, memory_buffer_size);
+        return true;
+    }
+
+    void alloc_params() {
+        ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);
+
+        input_hint_block.init_params(ctx);
+
+        time_embed_0_w = ggml_new_tensor_2d(ctx, wtype, model_channels, time_embed_dim);
+        time_embed_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, time_embed_dim);
+        time_embed_2_w = ggml_new_tensor_2d(ctx, wtype, time_embed_dim, time_embed_dim);
+        time_embed_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, time_embed_dim);
+
+        // input_blocks
+        input_block_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, model_channels);
+        input_block_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model_channels);
+
+        int ds        = 1;
+        int len_mults = channel_mult.size();
+        for (int i = 0; i < len_mults; i++) {
+            for (int j = 0; j < num_res_blocks; j++) {
+                input_res_blocks[i][j].init_params(ctx, wtype);
+                if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
+                    input_transformers[i][j].init_params(ctx, alloc, wtype);
+                }
+            }
+            if (i != len_mults - 1) {
+                input_down_samples[i].init_params(ctx, wtype);
+                ds *= 2;
+            }
+        }
+
+        for (int i = 0; i < num_zero_convs; i++) {
+            zero_convs[i].init_params(ctx);
+        }
+
+        // middle_blocks
+        middle_block_0.init_params(ctx, wtype);
+        middle_block_1.init_params(ctx, alloc, wtype);
+        middle_block_2.init_params(ctx, wtype);
+
+        // middle_block_out
+        middle_block_out_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, middle_out_channel, middle_out_channel);
+        middle_block_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, middle_out_channel);
+
+        // alloc all tensors linked to this context
+        for (struct ggml_tensor* t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->data == NULL) {
+                ggml_allocr_alloc(alloc, t);
+            }
+        }
+
+        ggml_allocr_free(alloc);
+    }
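+    // Note: ctx was created with no_alloc = true, so the ggml_new_tensor_*
+    // calls above only record tensor metadata; the final loop hands every
+    // still-unallocated tensor to the ggml_allocr, which places its data
+    // inside params_buffer on the target backend.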
+
+    bool load_from_file(const std::string& file_path, ggml_backend_t backend, ggml_type wtype) {
+        LOG_INFO("loading control net from '%s'", file_path.c_str());
+
+        std::map<std::string, ggml_tensor*> control_tensors;
+
+        ModelLoader model_loader;
+        if (!model_loader.init_from_file(file_path)) {
+            LOG_ERROR("init control net model loader from file failed: '%s'", file_path.c_str());
+            return false;
+        }
+
+        if (!init(backend, wtype)) {
+            return false;
+        }
+
+        // prepare memory for the weights
+        {
+            alloc_params();
+            map_by_name(control_tensors, "");
+        }
+
+        std::set<std::string> tensor_names_in_file;
+
+        auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
+            const std::string& name = tensor_storage.name;
+            tensor_names_in_file.insert(name);
+
+            struct ggml_tensor* real;
+            if (control_tensors.find(name) != control_tensors.end()) {
+                real = control_tensors[name];
+            } else {
+                LOG_ERROR("unknown tensor '%s' in model file", name.data());
+                return true;
+            }
+
+            if (
+                real->ne[0] != tensor_storage.ne[0] ||
+                real->ne[1] != tensor_storage.ne[1] ||
+                real->ne[2] != tensor_storage.ne[2] ||
+                real->ne[3] != tensor_storage.ne[3]) {
+                LOG_ERROR(
+                    "tensor '%s' has wrong shape in model file: "
+                    "got [%d, %d, %d, %d], expected [%d, %d, %d, %d]",
+                    name.c_str(),
+                    (int)tensor_storage.ne[0], (int)tensor_storage.ne[1], (int)tensor_storage.ne[2], (int)tensor_storage.ne[3],
+                    (int)real->ne[0], (int)real->ne[1], (int)real->ne[2], (int)real->ne[3]);
+                return false;
+            }
+
+            *dst_tensor = real;
+
+            return true;
+        };
+
+        bool success = model_loader.load_tensors(on_new_tensor_cb, backend);
+
+        bool some_tensor_not_init = false;
+
+        for (auto pair : control_tensors) {
+            if (tensor_names_in_file.find(pair.first) == tensor_names_in_file.end()) {
+                LOG_ERROR("tensor '%s' not in model file", pair.first.c_str());
+                some_tensor_not_init = true;
+            }
+        }
+
+        if (some_tensor_not_init) {
+            return false;
+        }
+
+        LOG_INFO("control net model loaded");
+        return success;
+    }
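A typical call site for this loader, mirroring the call wired into StableDiffusionGGML later in this series (a sketch; error handling is up to the caller):

ControlNet control_net;
if (!control_net.load_from_file("models/control_openpose-fp16.safetensors", backend, GGML_TYPE_F16)) {
    LOG_ERROR("failed to load control net");
}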
+ std::to_string(input_block_idx) + ".0."); + ds *= 2; + } + } + + for (int i = 0; i < num_zero_convs; i++) { + tensors[prefix + "zero_convs."+ std::to_string(i) + ".0.weight"] = zero_convs[i].conv_w; + tensors[prefix + "zero_convs."+ std::to_string(i) + ".0.bias"] = zero_convs[i].conv_b; + } + + // middle_blocks + middle_block_0.map_by_name(tensors, prefix + "middle_block.0."); + middle_block_1.map_by_name(tensors, prefix + "middle_block.1."); + middle_block_2.map_by_name(tensors, prefix + "middle_block.2."); + + tensors[prefix + "middle_block_out.0.weight"] = middle_block_out_w; + tensors[prefix + "middle_block_out.0.bias"] = middle_block_out_b; + } + + struct std::vector forward(struct ggml_cgraph* gf, struct ggml_context* ctx0, + struct ggml_tensor* x, + struct ggml_tensor* hint, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL) { + // x: [N, in_channels, h, w] + // timesteps: [N, ] + // t_emb: [N, model_channels] + // context: [N, max_position, hidden_size]([N, 77, 768]) + // y: [adm_in_channels] + if (t_emb == NULL && timesteps != NULL) { + t_emb = new_timestep_embedding(ctx0, compute_alloc, timesteps, model_channels); // [N, model_channels] + } + + // time_embed = nn.Sequential + auto emb = ggml_nn_linear(ctx0, t_emb, time_embed_0_w, time_embed_0_b); + emb = ggml_silu_inplace(ctx0, emb); + emb = ggml_nn_linear(ctx0, emb, time_embed_2_w, time_embed_2_b); // [N, time_embed_dim] + + auto guided_hint = input_hint_block.forward(ctx0, hint); + + // input_blocks + std::vector outs; + int zero_conv_offset = 0; + + // input block 0 + struct ggml_tensor* h = ggml_nn_conv_2d(ctx0, x, input_block_0_w, input_block_0_b, 1, 1, 1, 1); // [N, model_channels, h, w] + h = ggml_add(ctx0, h, guided_hint); + h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b); + zero_conv_offset++; + outs.push_back(h); + + // input block 1-11 + int len_mults = channel_mult.size(); + int ds = 1; + for (int i = 0; i < len_mults; i++) { + int mult = channel_mult[i]; + for (int j = 0; j < num_res_blocks; j++) { + h = input_res_blocks[i][j].forward(ctx0, h, emb); // [N, mult*model_channels, h, w] + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + h = input_transformers[i][j].forward(ctx0, h, context); // [N, mult*model_channels, h, w] + } + h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b); + zero_conv_offset++; + outs.push_back(h); + } + if (i != len_mults - 1) { + ds *= 2; + h = input_down_samples[i].forward(ctx0, h); // [N, mult*model_channels, h/(2^(i+1)), w/(2^(i+1))] + h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b); + zero_conv_offset++; + outs.push_back(h); + } + } + // [N, 4*model_channels, h/8, w/8] + + // middle_block + h = middle_block_0.forward(ctx0, h, emb); // [N, 4*model_channels, h/8, w/8] + h = middle_block_1.forward(ctx0, h, context); // [N, 4*model_channels, h/8, w/8] + h = middle_block_2.forward(ctx0, h, emb); // [N, 4*model_channels, h/8, w/8] + + h = ggml_nn_conv_2d(ctx0, h, middle_block_out_w, middle_block_out_b); + outs.push_back(h); + ggml_build_forward_expand(gf, h); + return outs; + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL, + struct ggml_tensor* y = NULL) { + // since we are using ggml-alloc, this buffer only 
+
+    struct ggml_cgraph* build_graph(struct ggml_tensor* x,
+                                    struct ggml_tensor* timesteps,
+                                    struct ggml_tensor* context,
+                                    struct ggml_tensor* t_emb = NULL,
+                                    struct ggml_tensor* y = NULL) {
+        // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
+        static size_t buf_size = ggml_tensor_overhead() * UNET_GRAPH_SIZE + ggml_graph_overhead();
+        static std::vector<char> buf(buf_size);
+
+        struct ggml_init_params params = {
+            /*.mem_size   =*/buf_size,
+            /*.mem_buffer =*/buf.data(),
+            /*.no_alloc   =*/true,  // the tensors will be allocated later by ggml_allocr_alloc_graph()
+        };
+        // LOG_DEBUG("mem_size %u ", params.mem_size);
+
+        struct ggml_context* ctx0 = ggml_init(params);
+
+        struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, UNET_GRAPH_SIZE, false);
+
+        // temporary tensors for transferring data from cpu to gpu if needed
+        struct ggml_tensor* x_t         = NULL;
+        struct ggml_tensor* timesteps_t = NULL;
+        struct ggml_tensor* context_t   = NULL;
+        struct ggml_tensor* t_emb_t     = NULL;
+        struct ggml_tensor* y_t         = NULL;
+
+        // it's performing a compute, check if backend isn't cpu
+        if (!ggml_backend_is_cpu(backend)) {
+            // pass input tensors to gpu memory
+            x_t       = ggml_dup_tensor(ctx0, x);
+            context_t = ggml_dup_tensor(ctx0, context);
+            ggml_allocr_alloc(compute_alloc, x_t);
+            if (timesteps != NULL) {
+                timesteps_t = ggml_dup_tensor(ctx0, timesteps);
+                ggml_allocr_alloc(compute_alloc, timesteps_t);
+            }
+            ggml_allocr_alloc(compute_alloc, context_t);
+            if (t_emb != NULL) {
+                t_emb_t = ggml_dup_tensor(ctx0, t_emb);
+                ggml_allocr_alloc(compute_alloc, t_emb_t);
+            }
+            if (y != NULL) {
+                y_t = ggml_dup_tensor(ctx0, y);
+                ggml_allocr_alloc(compute_alloc, y_t);
+            }
+            // pass data to device backend
+            if (!ggml_allocr_is_measure(compute_alloc)) {
+                ggml_backend_tensor_set(x_t, x->data, 0, ggml_nbytes(x));
+                ggml_backend_tensor_set(context_t, context->data, 0, ggml_nbytes(context));
+                if (timesteps_t != NULL) {
+                    ggml_backend_tensor_set(timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps));
+                }
+                if (t_emb_t != NULL) {
+                    ggml_backend_tensor_set(t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb));
+                }
+                if (y != NULL) {
+                    ggml_backend_tensor_set(y_t, y->data, 0, ggml_nbytes(y));
+                }
+            }
+        } else {
+            // if it's cpu backend just pass the same tensors
+            x_t         = x;
+            timesteps_t = timesteps;
+            context_t   = context;
+            t_emb_t     = t_emb;
+            y_t         = y;
+        }
+
+        std::vector<struct ggml_tensor*> outs = forward(gf, ctx0, x_t, timesteps_t, context_t, t_emb_t, y_t);
+
+        ggml_free(ctx0);
+
+        return gf;
+    }
+
+    void begin(struct ggml_tensor* x,
+               struct ggml_tensor* context,
+               struct ggml_tensor* t_emb = NULL,
+               struct ggml_tensor* y = NULL) {
+        if (compute_memory_buffer_size == -1) {
+            // alignment required by the backend
+            compute_alloc = ggml_allocr_new_measure_from_backend(backend);
+
+            struct ggml_cgraph* gf = build_graph(x, NULL, context, t_emb, y);
+
+            // compute the required memory
+            compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf);
+
+            // recreate the allocator with the required memory
+            ggml_allocr_free(compute_alloc);
+
+            LOG_DEBUG("diffusion compute buffer size: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0);
+        }
+
+        compute_buffer = ggml_backend_alloc_buffer(backend, compute_memory_buffer_size);
+        compute_alloc  = ggml_allocr_new_from_buffer(compute_buffer);
+    }
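A sketch of the intended call pattern for these methods (patch 02 wires the real calls into StableDiffusionGGML::sample): begin() runs one measure pass to size the compute buffer, compute() can then be called once per denoising step, and end() releases the buffer. The variables x, c, t_emb, work_latent, sample_steps and n_threads are hypothetical caller-side names here.

control.begin(x, c, t_emb);
for (int step = 0; step < sample_steps; step++) {
    // fills work_latent with the graph's final output for this step
    control.compute(work_latent, n_threads, x, NULL, c, t_emb);
}
control.end();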
+
+    void compute(struct ggml_tensor* work_latent,
+                 int n_threads,
+                 struct ggml_tensor* x,
+                 struct ggml_tensor* timesteps,
+                 struct ggml_tensor* context,
+                 struct ggml_tensor* t_emb = NULL,
+                 struct ggml_tensor* y = NULL) {
+        ggml_allocr_reset(compute_alloc);
+
+        // compute
+        struct ggml_cgraph* gf = build_graph(x, timesteps, context, t_emb, y);
+
+        ggml_allocr_alloc_graph(compute_alloc, gf);
+
+        if (ggml_backend_is_cpu(backend)) {
+            ggml_backend_cpu_set_n_threads(backend, n_threads);
+        }
+
+#ifdef SD_USE_METAL
+        if (ggml_backend_is_metal(backend)) {
+            ggml_backend_metal_set_n_cb(backend, n_threads);
+        }
+#endif
+
+        ggml_backend_graph_compute(backend, gf);
+
+#ifdef GGML_PERF
+        ggml_graph_print(gf);
+#endif
+
+        ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_latent->data, 0, ggml_nbytes(work_latent));
+    }
+
+    void end() {
+        ggml_allocr_free(compute_alloc);
+        ggml_backend_buffer_free(compute_buffer);
+        compute_alloc              = NULL;
+        compute_memory_buffer_size = -1;
+    }
+};
+
 float ggml_backend_tensor_get_f32(ggml_tensor* tensor) {
     GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16);
     float value;
@@ -5326,6 +6021,7 @@ class StableDiffusionGGML {
     ESRGAN esrgan_upscaler;
     std::string esrgan_path;
 
+    ControlNet control;
     bool upscale_output = false;
 
     StableDiffusionGGML() = default;
@@ -5627,6 +6323,7 @@ class StableDiffusionGGML {
         if (use_tiny_autoencoder) {
            return tae_first_stage.load_from_file(taesd_path, backend);
        }
+        control.load_from_file("models/control_openpose-fp16.safetensors", backend, wtype);
        return true;
     }

From 0856d7f666be36e481c3ecf5298145e4d6e7f81c Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Fri, 29 Dec 2023 14:19:15 -0500
Subject: [PATCH 02/29] add controlnet to pipeline

---
 assets/control.png    | Bin 0 -> 4387 bytes
 examples/cli/main.cpp |   5 +-
 stable-diffusion.cpp  | 162 +++++++++++++++++++++++++++---------------
 stable-diffusion.h    |   3 +-
 4 files changed, 109 insertions(+), 61 deletions(-)
 create mode 100644 assets/control.png

diff --git a/assets/control.png b/assets/control.png
new file mode 100644
index 0000000000000000000000000000000000000000..3ed95d0938c5b87dde60bbef64a9c26cc72fc39a
GIT binary patch
(base85-encoded binary data for assets/control.png omitted)
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 6264d6e2f..fe36a8e9e 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -490,6 +490,8 @@ int main(int argc, const char* argv[]) {
     std::vector<uint8_t*> results;
     if (params.mode == TXT2IMG) {
+        int c = 0;
+        input_image_buffer = stbi_load("assets/control.png", &params.width, &params.height, &c, 3);
         results = sd.txt2img(params.prompt,
                              params.negative_prompt,
                              params.cfg_scale,
@@ -498,7 +500,8 @@ int main(int argc, const char* argv[]) {
                              params.sample_method,
                              params.sample_steps,
                              params.seed,
-                             params.batch_count);
+                             params.batch_count,
+                             input_image_buffer);
     } else {
         results = sd.img2img(input_image_buffer,
                              params.prompt,
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 44dbe85ac..811481143 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -2699,6 +2699,7 @@ struct UNetModel {
                                struct ggml_tensor* x,
                                struct ggml_tensor* timesteps,
                                struct ggml_tensor* context,
+                               std::vector<struct ggml_tensor*> control,
                                struct ggml_tensor* t_emb = NULL,
                                struct ggml_tensor* y = NULL) {
         // x: [N, in_channels, h, w]
@@ -2755,12 +2756,20 @@ struct UNetModel {
         h = middle_block_0.forward(ctx0, h, emb);      // [N, 4*model_channels, h/8, w/8]
         h = middle_block_1.forward(ctx0, h, context);  // [N, 4*model_channels, h/8, w/8]
         h = middle_block_2.forward(ctx0, h, emb);      // [N, 4*model_channels, h/8, w/8]
+        if(control.size() > 0) {
+            h = ggml_add(ctx0, h, control[control.size() - 1]);  // middle control
+        }
 
+        int control_offset = control.size() - 2;
         // output_blocks
         for (int i = len_mults - 1; i >= 0; i--) {
             for (int j = 0; j < num_res_blocks + 1; j++) {
                 auto h_skip = hs.back();
                 hs.pop_back();
 
+                if(control.size() > 0) {
+                    h_skip = ggml_add(ctx0, h_skip, control[control_offset]);  // control net condition
+                    control_offset--;
+                }
                 h = ggml_concat(ctx0, h, h_skip);
                 h = output_res_blocks[i][j].forward(ctx0, h, emb);
 
@@ -2790,6 +2799,7 @@ struct UNetModel {
     struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                     struct ggml_tensor* timesteps,
                                     struct ggml_tensor* context,
+                                    std::vector<struct ggml_tensor*> control,
                                     struct ggml_tensor* t_emb = NULL,
                                     struct ggml_tensor* y = NULL) {
         // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
@@ -2856,7 +2866,7 @@ struct UNetModel {
             y_t         = y;
         }
 
-        struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, t_emb_t, y_t);
+        struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, control, t_emb_t, y_t);
 
         ggml_build_forward_expand(gf, out);
         ggml_free(ctx0);
@@ -2866,13 +2876,14 @@ struct UNetModel {
 
     void begin(struct ggml_tensor* x,
                struct ggml_tensor* context,
+               std::vector<struct ggml_tensor*> control,
               struct ggml_tensor* t_emb = NULL,
               struct ggml_tensor* y = NULL) {
         if (compute_memory_buffer_size == -1) {
             // alignment required by the backend
             compute_alloc = ggml_allocr_new_measure_from_backend(backend);
 
-            struct ggml_cgraph* gf = build_graph(x, NULL, context, t_emb, y);
+            struct ggml_cgraph* gf = build_graph(x, NULL, context, control, t_emb, y);
 
             // compute the required memory
             compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf);
@@ -2892,12 +2903,13 @@ struct UNetModel {
                  struct ggml_tensor* x,
                  struct ggml_tensor* timesteps,
                  struct ggml_tensor* context,
+                 std::vector<struct ggml_tensor*> control,
                  struct ggml_tensor* t_emb = NULL,
                  struct ggml_tensor* y = NULL) {
         ggml_allocr_reset(compute_alloc);
 
         // compute
-        struct ggml_cgraph* gf = build_graph(x, timesteps, context, t_emb, y);
+        struct ggml_cgraph* gf = build_graph(x, timesteps, context, control, t_emb, y);
 
         ggml_allocr_alloc_graph(compute_alloc, gf);
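For a 13-entry control vector (12 input-block outputs plus one middle-block output, as ControlNet produces for SD 1.5), the code above consumes the entries back to front, in step with the hs skip-connection stack. A tiny standalone trace of the indexing:

#include <cstdio>

int main() {
    int n_control = 13;  // 12 zero-conv outputs + 1 middle output
    printf("middle block <- control[%d]\n", n_control - 1);
    int control_offset = n_control - 2;
    for (int k = 0; k < 12; k++) {  // output blocks pop hs back to front
        printf("h_skip       <- control[%d]\n", control_offset--);
    }
    return 0;
}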
@@ -5019,22 +5031,21 @@ struct CNHintBlock {
         tensors[prefix + "input_hint_block.14.bias"]   = conv_final_b;
     }
 
-    ggml_tensor* forward(ggml_context* ctx, ggml_tensor* x) {
-        ggml_tensor* h;
-        h = ggml_nn_conv_2d(ctx, x, conv_first_w, conv_first_b, 1, 1, 1, 1);
+    struct ggml_tensor* forward(ggml_context* ctx, struct ggml_tensor* x) {
+        auto h = ggml_nn_conv_2d(ctx, x, conv_first_w, conv_first_b, 1, 1, 1, 1);
         h = ggml_silu_inplace(ctx, h);
 
+        auto body_h = h;
         for(int i = 0; i < num_blocks; i++) {
             // operations.conv_nd(dims, 16, 16, 3, padding=1)
-            h = ggml_nn_conv_2d(ctx, h, blocks[i].conv_0_w, blocks[i].conv_0_b, 1, 1, 1, 1);
-            h = ggml_silu_inplace(ctx, h);
-
+            body_h = ggml_nn_conv_2d(ctx, body_h, blocks[i].conv_0_w, blocks[i].conv_0_b, 1, 1, 1, 1);
+            body_h = ggml_silu_inplace(ctx, body_h);
             // operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2)
-            h = ggml_nn_conv_2d(ctx, h, blocks[i].conv_1_w, blocks[i].conv_1_b, 2, 2, 1, 1);
-            h = ggml_silu_inplace(ctx, h);
+            body_h = ggml_nn_conv_2d(ctx, body_h, blocks[i].conv_1_w, blocks[i].conv_1_b, 2, 2, 1, 1);
+            body_h = ggml_silu_inplace(ctx, body_h);
         }
 
-        h = ggml_nn_conv_2d(ctx, h, conv_final_w, conv_final_b, 1, 1, 1, 1);
+        h = ggml_nn_conv_2d(ctx, body_h, conv_final_w, conv_final_b, 1, 1, 1, 1);
         h = ggml_silu_inplace(ctx, h);
         return h;
     }
@@ -5435,7 +5446,8 @@ struct ControlNet {
         tensors[prefix + "middle_block_out.0.bias"]   = middle_block_out_b;
     }
 
-    struct std::vector<struct ggml_tensor*> forward(struct ggml_cgraph* gf, struct ggml_context* ctx0,
+    struct ggml_tensor* forward(std::vector<struct ggml_tensor*>& outs,
+                                struct ggml_context* ctx0,
                                 struct ggml_tensor* x,
                                 struct ggml_tensor* hint,
                                 struct ggml_tensor* timesteps,
                                 struct ggml_tensor* context,
                                 struct ggml_tensor* t_emb = NULL) {
@@ -5458,13 +5470,12 @@ struct ControlNet {
         auto guided_hint = input_hint_block.forward(ctx0, hint);
 
         // input_blocks
-        std::vector<struct ggml_tensor*> outs;
         int zero_conv_offset = 0;
 
         // input block 0
         struct ggml_tensor* h = ggml_nn_conv_2d(ctx0, x, input_block_0_w, input_block_0_b, 1, 1, 1, 1);  // [N, model_channels, h, w]
-        h = ggml_add(ctx0, h, guided_hint);
-        h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+        // h = ggml_add(ctx0, h, guided_hint);
+        // h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
         zero_conv_offset++;
         outs.push_back(h);
 
@@ -5478,15 +5489,15 @@ struct ControlNet {
                 if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
                     h = input_transformers[i][j].forward(ctx0, h, context);  // [N, mult*model_channels, h, w]
                 }
-                h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
-                zero_conv_offset++;
+                // h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+                // zero_conv_offset++;
                 outs.push_back(h);
             }
             if (i != len_mults - 1) {
                 ds *= 2;
                 h = input_down_samples[i].forward(ctx0, h);  // [N, mult*model_channels, h/(2^(i+1)), w/(2^(i+1))]
-                h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
-                zero_conv_offset++;
+                // h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+                // zero_conv_offset++;
                 outs.push_back(h);
             }
         }
         // [N, 4*model_channels, h/8, w/8]
@@ -5497,17 +5508,17 @@ struct ControlNet {
         h = middle_block_1.forward(ctx0, h, context);  // [N, 4*model_channels, h/8, w/8]
         h = middle_block_2.forward(ctx0, h, emb);      // [N, 4*model_channels, h/8, w/8]
 
-        h = ggml_nn_conv_2d(ctx0, h, middle_block_out_w, middle_block_out_b);
+        // h = ggml_nn_conv_2d(ctx0, h, middle_block_out_w, middle_block_out_b);
         outs.push_back(h);
-        ggml_build_forward_expand(gf, h);
-        return outs;
+        return h;
     }
 
-    struct ggml_cgraph* build_graph(struct ggml_tensor* x,
+    std::pair<struct ggml_cgraph*, std::vector<struct ggml_tensor*>>
+    build_graph(struct ggml_tensor* x,
+                struct ggml_tensor* hint,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
-                struct ggml_tensor* t_emb = NULL,
-                struct ggml_tensor* y = NULL) {
+                struct ggml_tensor* t_emb = NULL) {
         // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
         static size_t buf_size = ggml_tensor_overhead() * UNET_GRAPH_SIZE + ggml_graph_overhead();
         static std::vector<char> buf(buf_size);
@@ -5525,43 +5536,39 @@ struct ControlNet {
 
         // temporary tensors for transferring data from cpu to gpu if needed
         struct ggml_tensor* x_t         = NULL;
+        struct ggml_tensor* hint_t      = NULL;
         struct ggml_tensor* timesteps_t = NULL;
         struct ggml_tensor* context_t   = NULL;
         struct ggml_tensor* t_emb_t     = NULL;
-        struct ggml_tensor* y_t         = NULL;
 
         // it's performing a compute, check if backend isn't cpu
         if (!ggml_backend_is_cpu(backend)) {
             // pass input tensors to gpu memory
             x_t       = ggml_dup_tensor(ctx0, x);
             context_t = ggml_dup_tensor(ctx0, context);
+            hint_t    = ggml_dup_tensor(ctx0, hint);
             ggml_allocr_alloc(compute_alloc, x_t);
             if (timesteps != NULL) {
                 timesteps_t = ggml_dup_tensor(ctx0, timesteps);
                 ggml_allocr_alloc(compute_alloc, timesteps_t);
             }
             ggml_allocr_alloc(compute_alloc, context_t);
+            ggml_allocr_alloc(compute_alloc, hint_t);
             if (t_emb != NULL) {
                 t_emb_t = ggml_dup_tensor(ctx0, t_emb);
                 ggml_allocr_alloc(compute_alloc, t_emb_t);
             }
-            if (y != NULL) {
-                y_t = ggml_dup_tensor(ctx0, y);
-                ggml_allocr_alloc(compute_alloc, y_t);
-            }
             // pass data to device backend
             if (!ggml_allocr_is_measure(compute_alloc)) {
                 ggml_backend_tensor_set(x_t, x->data, 0, ggml_nbytes(x));
                 ggml_backend_tensor_set(context_t, context->data, 0, ggml_nbytes(context));
+                ggml_backend_tensor_set(hint_t, hint->data, 0, ggml_nbytes(hint));
                 if (timesteps_t != NULL) {
                     ggml_backend_tensor_set(timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps));
                 }
                 if (t_emb_t != NULL) {
                     ggml_backend_tensor_set(t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb));
                 }
-                if (y != NULL) {
-                    ggml_backend_tensor_set(y_t, y->data, 0, ggml_nbytes(y));
-                }
             }
         } else {
             // if it's cpu backend just pass the same tensors
@@ -5569,25 +5576,30 @@ struct ControlNet {
             timesteps_t = timesteps;
             context_t   = context;
             t_emb_t     = t_emb;
-            y_t         = y;
+            hint_t      = hint;
         }
 
-        struct std::vector<struct ggml_tensor*> outs = forward(gf, ctx0, x_t, timesteps_t, context_t, t_emb_t, y_t);
+        std::vector<struct ggml_tensor*> outputs;
+        struct ggml_tensor* out = forward(outputs, ctx0, x_t, hint_t, timesteps_t, context_t, t_emb_t);
 
+        ggml_build_forward_expand(gf, out);
         ggml_free(ctx0);
 
-        return gf;
+        return std::make_pair(gf, outputs);
     }
 
-    void begin(struct ggml_tensor* x,
+    void begin(ggml_context* draft, std::vector<struct ggml_tensor*>& controls,
+               struct ggml_tensor* x,
+               struct ggml_tensor* hint,
               struct ggml_tensor* context,
-              struct ggml_tensor* t_emb = NULL,
-              struct ggml_tensor* y = NULL) {
+              struct ggml_tensor* t_emb = NULL) {
         if (compute_memory_buffer_size == -1) {
             // alignment required by the backend
             compute_alloc = ggml_allocr_new_measure_from_backend(backend);
 
-            struct ggml_cgraph* gf = build_graph(x, NULL, context, t_emb, y);
+            auto res = build_graph(x, hint, NULL, context, t_emb);
+            struct ggml_cgraph* gf = res.first;
+            std::vector<struct ggml_tensor*> outs = res.second;
 
             // compute the required memory
             compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf);
@@ -5595,24 +5607,30 @@ struct ControlNet {
             // recreate the allocator with the required memory
             ggml_allocr_free(compute_alloc);
 
-            LOG_DEBUG("diffusion compute buffer size: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0);
+            for(int i = 0; i < outs.size(); i++) {
+                controls.push_back(ggml_dup_tensor(draft, outs[i]));
+            }
+
+            LOG_DEBUG("controlnet compute buffer size: %.2f MB -> %zu control output tensors", compute_memory_buffer_size / 1024.0 / 1024.0, res.second.size());
         }
 
         compute_buffer = ggml_backend_alloc_buffer(backend, compute_memory_buffer_size);
         compute_alloc  = ggml_allocr_new_from_buffer(compute_buffer);
     }
 
-    void compute(struct ggml_tensor* work_latent,
+    void compute(std::vector<struct ggml_tensor*>& controls,
                  int n_threads,
                  struct ggml_tensor* x,
+                 struct ggml_tensor* hint,
                  struct ggml_tensor* timesteps,
                  struct ggml_tensor* context,
-                 struct ggml_tensor* t_emb = NULL,
-                 struct ggml_tensor* y = NULL) {
+                 struct ggml_tensor* t_emb = NULL) {
         ggml_allocr_reset(compute_alloc);
 
         // compute
-        struct ggml_cgraph* gf = build_graph(x, timesteps, context, t_emb, y);
+        auto res = build_graph(x, hint, NULL, context, t_emb);
+        struct ggml_cgraph* gf = res.first;
+        std::vector<struct ggml_tensor*> outs = res.second;
 
         ggml_allocr_alloc_graph(compute_alloc, gf);
 
@@ -5631,8 +5649,11 @@ struct ControlNet {
 #ifdef GGML_PERF
         ggml_graph_print(gf);
 #endif
-
-        ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_latent->data, 0, ggml_nbytes(work_latent));
+
+        // extract the tensors from the graph
+        for(int i = 0; i < outs.size(); i++) {
+            ggml_backend_tensor_get_and_sync(backend, outs[i], controls[i]->data, 0, ggml_nbytes(controls[i]));
+        }
     }
 
     void end() {
@@ -6021,7 +6042,7 @@ class StableDiffusionGGML {
     ESRGAN esrgan_upscaler;
     std::string esrgan_path;
 
-    ControlNet control;
+    ControlNet control_net;
     bool upscale_output = false;
 
     StableDiffusionGGML() = default;
@@ -6323,7 +6344,7 @@ class StableDiffusionGGML {
         if (use_tiny_autoencoder) {
             return tae_first_stage.load_from_file(taesd_path, backend);
         }
-        control.load_from_file("models/control_openpose-fp16.safetensors", backend, wtype);
+        control_net.load_from_file("models/control_openpose-fp16.safetensors", backend, wtype);
         return true;
     }
 
@@ -6336,13 +6357,13 @@
         struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1);                                      // [N, ]
         struct ggml_tensor* t_emb     = new_timestep_embedding(work_ctx, NULL, timesteps, diffusion_model.model_channels);  // [N, model_channels]
 
-        diffusion_model.begin(x_t, c, t_emb);
+        diffusion_model.begin(x_t, c, std::vector<struct ggml_tensor*>(), t_emb);
 
         int64_t t0 = ggml_time_ms();
         ggml_set_f32(timesteps, 999);
         set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
         struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
-        diffusion_model.compute(out, n_threads, x_t, NULL, c, t_emb);
+        diffusion_model.compute(out, n_threads, x_t, NULL, c, std::vector<struct ggml_tensor*>(), t_emb);
         diffusion_model.end();
 
         double result = 0.f;
@@ -6509,6 +6530,7 @@ class StableDiffusionGGML {
                        ggml_tensor* c_vector,
                        ggml_tensor* uc,
                        ggml_tensor* uc_vector,
+                       ggml_tensor* control_hint,
                        float cfg_scale,
                        SampleMethod method,
                        const std::vector<float>& sigmas) {
@@ -6521,7 +6543,14 @@ class StableDiffusionGGML {
         struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x_t);
         struct ggml_tensor* timesteps    = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1);                                      // [N, ]
         struct ggml_tensor* t_emb        = new_timestep_embedding(work_ctx, NULL, timesteps, diffusion_model.model_channels);  // [N, model_channels]
-        diffusion_model.begin(noised_input, c, t_emb, c_vector);
+
+        std::vector<struct ggml_tensor*> controls;
+
+        if(control_hint != NULL) {
+            control_net.begin(work_ctx, controls, noised_input, control_hint, c, t_emb);
+        }
+
+        diffusion_model.begin(noised_input, c, controls, t_emb, c_vector);
 
         bool has_unconditioned = cfg_scale != 1.0 && uc != NULL;
 
@@ -6571,12 +6600,20 @@ class StableDiffusionGGML {
                 ggml_tensor_scale(noised_input, c_in);
 
                 // cond
-                diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, t_emb, c_vector);
+                if(control_hint != NULL) {
+                    LOG_DEBUG("control net start");
+                    control_net.compute(controls, n_threads, noised_input, control_hint, NULL, c, t_emb);
+                    LOG_DEBUG("control net end");
+                }
+                diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, controls, t_emb, c_vector);
 
                 float* negative_data = NULL;
                 if (has_unconditioned) {
                     // uncond
-                    diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, t_emb, uc_vector);
+                    if(control_hint != NULL) {
+                        control_net.compute(controls, n_threads, noised_input, control_hint, NULL, uc, t_emb);
+                    }
+                    diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, controls, t_emb, uc_vector);
                     negative_data = (float*)out_uncond->data;
                 }
                 float* vec_denoised = (float*)denoised->data;
@@ -7160,7 +7197,8 @@ std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,
                                                SampleMethod sample_method,
                                                int sample_steps,
                                                int64_t seed,
-                                               int batch_count) {
+                                               int batch_count,
+                                               const uint8_t* control_cond) {
    std::vector<uint8_t*> results;
    // if (width >= 1024 && height >= 1024) {  // 1024 x 1024 images
    //     LOG_WARN("Image too large, try a smaller size.");
@@ -7182,7 +7220,7 @@
     int64_t t1 = ggml_time_ms();
     LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
     struct ggml_init_params params;
-    params.mem_size = static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
+    params.mem_size = static_cast<size_t>(80 * 1024 * 1024);  // 80 MB
     params.mem_size += width * height * 3 * sizeof(float);
     params.mem_size *= batch_count;
     params.mem_buffer = NULL;
@@ -7225,6 +7263,12 @@
         sd->cond_stage_model.destroy();
     }
 
+    struct ggml_tensor* image_hint = NULL;
+    if(control_cond != NULL) {
+        image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+        sd_image_to_tensor(control_cond, image_hint);
+    }
+
    std::vector<struct ggml_tensor*> final_latents;  // collect latents to decode
    int C = 4;
    int W = width / 8;
@@ -7241,7 +7285,7 @@
 
         std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
 
-        struct ggml_tensor* x_0 = sd->sample(work_ctx, x_t, NULL, c, c_vector, uc, uc_vector, cfg_scale, sample_method, sigmas);
+        struct ggml_tensor* x_0 = sd->sample(work_ctx, x_t, NULL, c, c_vector, uc, uc_vector, image_hint, cfg_scale, sample_method, sigmas);
         // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
         // print_ggml_tensor(x_0);
         int64_t sampling_end = ggml_time_ms();
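For context, out_cond and out_uncond computed in the sampling loop above feed the standard classifier-free-guidance blend. A simplified standalone sketch (variable names are illustrative, not the file's):

// denoised = uncond + cfg_scale * (cond - uncond)
void cfg_blend(float* denoised, const float* cond, const float* uncond,
               int64_t n, float cfg_scale) {
    for (int64_t i = 0; i < n; i++) {
        denoised[i] = uncond[i] + cfg_scale * (cond[i] - uncond[i]);
    }
}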
@@ -7388,7 +7432,7 @@
     ggml_tensor_set_f32_randn(noise, sd->rng);
 
     LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
-    struct ggml_tensor* x_0 = sd->sample(work_ctx, init_latent, noise, c, c_vector, uc, uc_vector, cfg_scale, sample_method, sigma_sched);
+    struct ggml_tensor* x_0 = sd->sample(work_ctx, init_latent, noise, c, c_vector, uc, uc_vector, NULL, cfg_scale, sample_method, sigma_sched);
     // struct ggml_tensor *x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
     // print_ggml_tensor(x_0);
     int64_t t3 = ggml_time_ms();
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 3ae012f96..03d1eff0b 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -63,7 +63,8 @@ class StableDiffusion {
         SampleMethod sample_method,
         int sample_steps,
         int64_t seed,
-        int batch_count);
+        int batch_count,
+        const uint8_t* control_cond);
 
     std::vector<uint8_t*> img2img(
         const uint8_t* init_img_data,

From 1d4262249af9b2c8baac0d1f43dcb97fd1df2257 Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Fri, 29 Dec 2023 14:25:53 -0500
Subject: [PATCH 03/29] decomment controlnet ops

---
 stable-diffusion.cpp | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 811481143..0b5a802ff 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -5474,8 +5474,8 @@ struct ControlNet {
 
         // input block 0
         struct ggml_tensor* h = ggml_nn_conv_2d(ctx0, x, input_block_0_w, input_block_0_b, 1, 1, 1, 1);  // [N, model_channels, h, w]
-        // h = ggml_add(ctx0, h, guided_hint);
-        // h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+        h = ggml_add(ctx0, h, guided_hint);
+        h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
         zero_conv_offset++;
         outs.push_back(h);
 
@@ -5489,15 +5489,17 @@ struct ControlNet {
                 if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
                     h = input_transformers[i][j].forward(ctx0, h, context);  // [N, mult*model_channels, h, w]
                 }
-                // h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
-                // zero_conv_offset++;
+
+                h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+                zero_conv_offset++;
                 outs.push_back(h);
             }
             if (i != len_mults - 1) {
                 ds *= 2;
                 h = input_down_samples[i].forward(ctx0, h);  // [N, mult*model_channels, h/(2^(i+1)), w/(2^(i+1))]
-                // h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
-                // zero_conv_offset++;
+                h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+
+                zero_conv_offset++;
                 outs.push_back(h);
             }
         }
@@ -5508,7 +5510,7 @@ struct ControlNet {
         h = middle_block_1.forward(ctx0, h, context);  // [N, 4*model_channels, h/8, w/8]
         h = middle_block_2.forward(ctx0, h, emb);      // [N, 4*model_channels, h/8, w/8]
 
-        // h = ggml_nn_conv_2d(ctx0, h, middle_block_out_w, middle_block_out_b);
+        h = ggml_nn_conv_2d(ctx0, h, middle_block_out_w, middle_block_out_b);
         outs.push_back(h);
         return h;
     }

From 285f6c6c0c4db1ae2689c702a87c3b7f5c050572 Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Fri, 29 Dec 2023 16:40:13 -0500
Subject: [PATCH 04/29] fix invalid variable (overflow index)

---
 stable-diffusion.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 0b5a802ff..326d6a3e4 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -5330,7 +5330,7 @@ struct ControlNet {
         ggml_allocr_free(alloc);
     }
 
-    bool load_from_file(const std::string& file_path, ggml_backend_t backend, ggml_type wtype) {
+    bool load_from_file(const std::string& file_path, ggml_backend_t backend, ggml_type wtype__) {
         LOG_INFO("loading control net from '%s'", file_path.c_str());
 
         std::map<std::string, ggml_tensor*> control_tensors;
@@ -5341,7 +5341,7 @@ struct ControlNet {
             return false;
         }
 
-        if (!init(backend, wtype)) {
+        if (!init(backend, wtype__)) {
             return false;
         }
 
@@ -6346,7 +6346,7 @@ class StableDiffusionGGML {
         if (use_tiny_autoencoder) {
             return tae_first_stage.load_from_file(taesd_path, backend);
         }
-        control_net.load_from_file("models/control_openpose-fp16.safetensors", backend, wtype);
+        control_net.load_from_file("models/control_openpose-fp16.safetensors", backend, GGML_TYPE_F16 /* just f16 controlnet models */);
         return true;
     }
 
@@ -6603,9 +6603,7 @@ class StableDiffusionGGML {
                 // cond
                 if(control_hint != NULL) {
-                    LOG_DEBUG("control net start");
                     control_net.compute(controls, n_threads, noised_input, control_hint, NULL, c, t_emb);
-                    LOG_DEBUG("control net end");
                 }
                 diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, controls, t_emb, c_vector);

From d08d48f4861815666524bc34bf02a556e68ccd3a Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Sun, 31 Dec 2023 12:51:49 -0500
Subject: [PATCH 05/29] fix NaNs + fix implementation

---
 stable-diffusion.cpp | 133 ++++++++++++++++++++++---------------------
 1 file changed, 67 insertions(+), 66 deletions(-)

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 326d6a3e4..0f71d3d1f 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -4983,8 +4983,8 @@ struct CNHintBlock {
     };
 
     hint_block blocks[3];
-    ggml_tensor* conv_final_w;  // [feat_channels[3], model_channels, 3, 3]
-    ggml_tensor* conv_final_b;  // [feat_channels[3]]
+    ggml_tensor* conv_final_w;  // [model_channels, feat_channels[3], 3, 3]
+    ggml_tensor* conv_final_b;  // [model_channels]
 
     size_t calculate_mem_size() {
         size_t mem_size = feat_channels[0] * hint_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_first_w
@@ -5103,8 +5103,8 @@ struct ControlNet {
     struct ggml_tensor* middle_block_out_b;  // [middle_out_channel, ]
 
     struct ggml_context* ctx;
-    ggml_backend_buffer_t params_buffer;
-    ggml_backend_buffer_t compute_buffer;  // for compute
+    ggml_backend_buffer_t params_buffer  = NULL;
+    ggml_backend_buffer_t compute_buffer = NULL;  // for compute
     struct ggml_allocr* compute_alloc = NULL;
     size_t compute_memory_buffer_size = -1;
 
@@ -5112,6 +5112,10 @@ struct ControlNet {
     ggml_type wtype;
     ggml_backend_t backend = NULL;
 
+    ggml_backend_buffer_t control_buffer = NULL;  // keep control output tensors in backend memory
+    ggml_context* control_ctx = NULL;
+    std::vector<struct ggml_tensor*> controls;  // (12 input block outputs, 1 middle block output) SD 1.5
+
     ControlNet() {
         // input_blocks
         std::vector<int> input_block_chans;
@@ -5446,18 +5450,17 @@ struct ControlNet {
         tensors[prefix + "middle_block_out.0.bias"]   = middle_block_out_b;
     }
 
-    struct ggml_tensor* forward(std::vector<struct ggml_tensor*>& outs,
-                                struct ggml_context* ctx0,
-                                struct ggml_tensor* x,
-                                struct ggml_tensor* hint,
-                                struct ggml_tensor* timesteps,
-                                struct ggml_tensor* context,
-                                struct ggml_tensor* t_emb = NULL) {
+    void forward(struct ggml_cgraph* gf,
+                 struct ggml_context* ctx0,
+                 struct ggml_tensor* x,
+                 struct ggml_tensor* hint,
+                 struct ggml_tensor* timesteps,
+                 struct ggml_tensor* context,
+                 struct ggml_tensor* t_emb = NULL) {
         // x: [N, in_channels, h, w]
         // timesteps: [N, ]
         // t_emb: [N, model_channels]
         // context: [N, max_position, hidden_size]([N, 77, 768])
-        // y: [adm_in_channels]
         if (t_emb == NULL && timesteps != NULL) {
             t_emb = new_timestep_embedding(ctx0, compute_alloc, timesteps, model_channels);  // [N, model_channels]
         }
@@ -5475,9 +5478,10 @@ struct ControlNet {
         // input block 0
         struct ggml_tensor* h = ggml_nn_conv_2d(ctx0, x, input_block_0_w, input_block_0_b, 1, 1, 1, 1);  // [N, model_channels, h, w]
         h = ggml_add(ctx0, h, guided_hint);
-        h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+
+        auto h_c = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset]));
         zero_conv_offset++;
-        outs.push_back(h);
 
         // input block 1-11
         int len_mults = channel_mult.size();
@@ -5489,18 +5493,16 @@ struct ControlNet {
                 if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
                     h = input_transformers[i][j].forward(ctx0, h, context);  // [N, mult*model_channels, h, w]
                 }
-
-                h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+                h_c = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset]));
                 zero_conv_offset++;
-                outs.push_back(h);
             }
             if (i != len_mults - 1) {
                 ds *= 2;
                 h = input_down_samples[i].forward(ctx0, h);  // [N, mult*model_channels, h/(2^(i+1)), w/(2^(i+1))]
-                h = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
-
+                h_c = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b);
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset]));
                 zero_conv_offset++;
-                outs.push_back(h);
             }
         }
@@ -5510,13 +5512,11 @@ struct ControlNet {
         h = middle_block_1.forward(ctx0, h, context);  // [N, 4*model_channels, h/8, w/8]
         h = middle_block_2.forward(ctx0, h, emb);      // [N, 4*model_channels, h/8, w/8]
 
-        h = ggml_nn_conv_2d(ctx0, h, middle_block_out_w, middle_block_out_b);
-        outs.push_back(h);
-        return h;
+        h_c = ggml_nn_conv_2d(ctx0, h, middle_block_out_w, middle_block_out_b);
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset]));
     }
 
-    std::pair<struct ggml_cgraph*, std::vector<struct ggml_tensor*>>
-    build_graph(struct ggml_tensor* x,
+    struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                 struct ggml_tensor* hint,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
                 struct ggml_tensor* t_emb = NULL) {
@@ -5581,58 +5581,65 @@ struct ControlNet {
             hint_t      = hint;
         }
 
-        std::vector<struct ggml_tensor*> outputs;
-        struct ggml_tensor* out = forward(outputs, ctx0, x_t, hint_t, timesteps_t, context_t, t_emb_t);
+        forward(gf, ctx0, x_t, hint_t, timesteps_t, context_t, t_emb_t);
 
-        ggml_build_forward_expand(gf, out);
         ggml_free(ctx0);
 
-        return std::make_pair(gf, outputs);
+        return gf;
     }
 
-    void begin(ggml_context* draft, std::vector<struct ggml_tensor*>& controls,
-               struct ggml_tensor* x,
+    void begin(struct ggml_tensor* x,
               struct ggml_tensor* hint,
               struct ggml_tensor* context,
               struct ggml_tensor* t_emb = NULL) {
+        struct ggml_init_params params;
+        params.mem_size   = static_cast<size_t>(13 * ggml_tensor_overhead()) + 256;
+        params.mem_buffer = NULL;
+        params.no_alloc   = true;
+        control_ctx = ggml_init(params);
-        if (compute_memory_buffer_size == -1) {
-            // alignment required by the backend
-            compute_alloc = ggml_allocr_new_measure_from_backend(backend);
 
+        size_t control_buffer_size = 0;
+        int w = x->ne[0], h = x->ne[1], steps = 0;
-            auto res = build_graph(x, hint, NULL, context, t_emb);
-            struct ggml_cgraph* gf = res.first;
-            std::vector<struct ggml_tensor*> outs = res.second;
 
-            // compute the required memory
-            compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf);
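+        // the loop below allocates one backend tensor per control output; for
+        // a 512x512 generation (x is the 64x64 latent) that is, in order:
+        //     3 x [64x64x320], [32x32x320], 2 x [32x32x640], [16x16x640],
+        //     2 x [16x16x1280], 3 x [8x8x1280], plus the final [8x8x1280]
+        // middle output -- the same shapes the UNet skip connections carry.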
+        for(int i = 0; i < (num_zero_convs + 1); i++) {
+            bool last = i == num_zero_convs;
+            int c = last ? middle_out_channel : zero_convs[i].channels;
+            if(!last && steps == 3) {
+                w /= 2; h /= 2; steps = 0;
+            }
+            controls.push_back(ggml_new_tensor_4d(control_ctx, GGML_TYPE_F32, w, h, c, 1));
+            control_buffer_size += ggml_nbytes(controls[i]);
+            steps++;
+        }
 
-            // recreate the allocator with the required memory
-            ggml_allocr_free(compute_alloc);
+        control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, backend);
 
-            for(int i = 0; i < outs.size(); i++) {
-                controls.push_back(ggml_dup_tensor(draft, outs[i]));
-            }
+        // alignment required by the backend
+        compute_alloc = ggml_allocr_new_measure_from_backend(backend);
 
+        struct ggml_cgraph* gf = build_graph(x, hint, NULL, context, t_emb);
+
+        // compute the required memory
+        compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf);
+
+        // recreate the allocator with the required memory
+        ggml_allocr_free(compute_alloc);
 
-            LOG_DEBUG("controlnet compute buffer size: %.2f MB -> %zu control output tensors", compute_memory_buffer_size / 1024.0 / 1024.0, res.second.size());
-        }
+        LOG_DEBUG("controlnet compute buffer size: %.2f MB (external buffer: %.2f MB)", (compute_memory_buffer_size + control_buffer_size) / 1024.0 / 1024.0, control_buffer_size / 1024.0 / 1024.0);
 
         compute_buffer = ggml_backend_alloc_buffer(backend, compute_memory_buffer_size);
         compute_alloc  = ggml_allocr_new_from_buffer(compute_buffer);
     }
 
-    void compute(std::vector<struct ggml_tensor*>& controls,
-                 int n_threads,
+    void compute(int n_threads,
                  struct ggml_tensor* x,
                  struct ggml_tensor* hint,
-                 struct ggml_tensor* timesteps,
                  struct ggml_tensor* context,
                  struct ggml_tensor* t_emb = NULL) {
         ggml_allocr_reset(compute_alloc);
 
         // compute
-        auto res = build_graph(x, hint, NULL, context, t_emb);
-        struct ggml_cgraph* gf = res.first;
-        std::vector<struct ggml_tensor*> outs = res.second;
+        struct ggml_cgraph* gf = build_graph(x, hint, NULL, context, t_emb);
 
         ggml_allocr_alloc_graph(compute_alloc, gf);
 
@@ -5651,11 +5658,6 @@ struct ControlNet {
 #ifdef GGML_PERF
         ggml_graph_print(gf);
 #endif
-
-        // extract the tensors from the graph
-        for(int i = 0; i < outs.size(); i++) {
-            ggml_backend_tensor_get_and_sync(backend, outs[i], controls[i]->data, 0, ggml_nbytes(controls[i]));
-        }
     }
 
     void end() {
@@ -6545,14 +6547,12 @@ class StableDiffusionGGML {
         struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x_t);
         struct ggml_tensor* timesteps    = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1);                                      // [N, ]
         struct ggml_tensor* t_emb        = new_timestep_embedding(work_ctx, NULL, timesteps, diffusion_model.model_channels);  // [N, model_channels]
-
-        std::vector<struct ggml_tensor*> controls;
 
         if(control_hint != NULL) {
-            control_net.begin(work_ctx, controls, noised_input, control_hint, c, t_emb);
+            control_net.begin(noised_input, control_hint, c, t_emb);
         }
 
-        diffusion_model.begin(noised_input, c, controls, t_emb, c_vector);
+        diffusion_model.begin(noised_input, c, control_net.controls, t_emb, c_vector);
 
         bool has_unconditioned = cfg_scale != 1.0 && uc != NULL;
 
@@ -6603,17 +6603,18 @@ class StableDiffusionGGML {
 
                 // cond
                 if(control_hint != NULL) {
-                    control_net.compute(controls, n_threads, noised_input, control_hint, NULL, c, t_emb);
+                    control_net.compute(n_threads, noised_input, control_hint, c, t_emb);
                 }
-                diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, controls, t_emb, c_vector);
+                diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, control_net.controls, t_emb, c_vector);
From 194df3abef9f727f4e3699d39f6dafdfbfdec232 Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Sun, 31 Dec 2023 14:59:24 -0500
Subject: [PATCH 06/29] add cli params

---
 assets/control_2.png  | Bin 0 -> 6237 bytes
 assets/control_3.png  | Bin 0 -> 18856 bytes
 examples/cli/main.cpp | 28 +++++++++++++++++++++++++---
 stable-diffusion.cpp  | 13 ++++++++++---
 stable-diffusion.h    |  1 +
 5 files changed, 36 insertions(+), 6 deletions(-)
 create mode 100644 assets/control_2.png
 create mode 100644 assets/control_3.png

diff --git a/assets/control_2.png b/assets/control_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..9352dc0f4490ba4ea682e910e1e8e092715ba6dd
GIT binary patch
[base85 binary literal for assets/control_2.png (6237 bytes) omitted]

[binary patch for assets/control_3.png (18856 bytes) omitted; its diff header did not survive extraction]
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index fe36a8e9e..463329011 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -60,10 +60,12 @@ struct SDParams {
     std::string vae_path;
     std::string taesd_path;
     std::string esrgan_path;
+    std::string controlnet_path;
     ggml_type wtype = GGML_TYPE_COUNT;
     std::string lora_model_dir;
     std::string output_path = "output.png";
     std::string input_path;
+    std::string control_image_path;

     std::string prompt;
     std::string negative_prompt;
@@ -121,11 +123,13 @@ void print_usage(int argc, const char* argv[]) {
     printf("  -m, --model [MODEL]                path to model\n");
     printf("  --vae [VAE]                        path to vae\n");
     printf("  --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
+    printf("  --control-net [CONTROL_PATH]       path to control net model\n");
     printf("  --upscale-model [ESRGAN_PATH]      path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n");
     printf("  --type [TYPE]                      weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n");
     printf("                                     If not specified, the default is the type of the weight file.\n");
     printf("  --lora-model-dir [DIR]             lora model directory\n");
     printf("  -i, --init-img [IMAGE]             path to the input image, required by img2img\n");
+    printf("  --control-image [IMAGE]            path to image condition, control net\n");
     printf("  -o, --output OUTPUT                path to write result image to (default: ./output.png)\n");
     printf("  -p, --prompt [PROMPT]              the prompt to render\n");
     printf("  -n, --negative-prompt PROMPT       the negative prompt (default: \"\")\n");
@@ -195,6 +199,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 break;
             }
             params.taesd_path = argv[i];
+        } else if (arg == "--control-net") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.controlnet_path = argv[i];
         } else if (arg == "--upscale-model") {
             if (++i >= argc) {
                 invalid_arg = true;
@@ -238,6 +248,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 break;
             }
             params.input_path = argv[i];
+        } else if (arg == "--control-image") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.control_image_path = argv[i];
         } else if (arg == "-o" || arg == "--output") {
             if (++i >= argc) {
                 invalid_arg = true;
@@ -484,14 +500,20 @@ int main(int argc, const char* argv[]) {
     StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type);

-    if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip)) {
+    if (!sd.load_from_file(params.model_path, params.vae_path, params.controlnet_path, params.wtype, params.schedule, params.clip_skip)) {
         return 1;
     }

     std::vector<uint8_t*> results;
     if (params.mode == TXT2IMG) {
-        int c              = 0;
-        input_image_buffer = stbi_load("assets/control.png", &params.width, &params.height, &c, 3);
+        if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
+            int c              = 0;
+            input_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
+            if (input_image_buffer == NULL) {
+                fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
+                return 1;
+            }
+        }
         results = sd.txt2img(params.prompt,
                              params.negative_prompt,
                              params.cfg_scale,
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 0f71d3d1f..419c817b1 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -6080,6 +6080,7 @@ class StableDiffusionGGML {

     bool load_from_file(const std::string& model_path,
                         const std::string& vae_path,
+                        const std::string control_net_path,
                         ggml_type wtype,
                         Schedule schedule,
                         int clip_skip) {
@@ -6338,6 +6339,7 @@ class StableDiffusionGGML {
             denoiser->schedule->sigmas[i]     = std::sqrt((1 - denoiser->schedule->alphas_cumprod[i]) / denoiser->schedule->alphas_cumprod[i]);
             denoiser->schedule->log_sigmas[i] = std::log(denoiser->schedule->sigmas[i]);
         }
+        LOG_DEBUG("finished loading file");
         ggml_free(ctx);

         if (upscale_output) {
@@ -6345,10 +6347,14 @@ class StableDiffusionGGML {
                 return false;
             }
         }
+
         if (use_tiny_autoencoder) {
             return tae_first_stage.load_from_file(taesd_path, backend);
         }
-        control_net.load_from_file("models/control_openpose-fp16.safetensors", backend, GGML_TYPE_F16 /* just f16 controlnet models */);
+
+        if (control_net_path.size() > 0) {
+            return control_net.load_from_file(control_net_path, backend, GGML_TYPE_F16 /* just f16 controlnet models */);
+        }
         return true;
     }
@@ -7184,10 +7190,11 @@ StableDiffusion::StableDiffusion(int n_threads,

 bool StableDiffusion::load_from_file(const std::string& model_path,
                                      const std::string& vae_path,
+                                     const std::string control_net_path,
                                      ggml_type wtype,
                                      Schedule s,
                                      int clip_skip) {
-    return sd->load_from_file(model_path, vae_path, wtype, s, clip_skip);
+    return sd->load_from_file(model_path, vae_path, control_net_path, wtype, s, clip_skip);
 }

 std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,
@@ -7221,7 +7228,7 @@ std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,
     int64_t t1 = ggml_time_ms();
     LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
     struct ggml_init_params params;
-    params.mem_size = static_cast<size_t>(35 * 1024 * 1024);  // 35 MB
+    params.mem_size = static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
     params.mem_size += width * height * 3 * sizeof(float);
     params.mem_size *= batch_count;
     params.mem_buffer = NULL;
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 03d1eff0b..f32f82bce 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -50,6 +50,7 @@ class StableDiffusion {

     bool load_from_file(const std::string& model_path,
                        const std::string& vae_path,
+                       const std::string control_net_path,
                        ggml_type wtype,
                        Schedule d = DEFAULT,
                        int clip_skip = -1);
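Patch 06 wires the hint image through the CLI, but the RGB bytes that stbi_load returns still have to become the float hint tensor the ControlNet consumes. A sketch of that conversion under the [W, H, C, N] layout used throughout this series, with values scaled to [0, 1]; hint_from_rgb is an illustrative helper, not a function from these patches:

    // Convert an 8-bit interleaved RGB buffer (as returned by stbi_load)
    // into a [width, height, 3, 1] float tensor with values in [0, 1].
    struct ggml_tensor* hint_from_rgb(struct ggml_context* work_ctx,
                                      const uint8_t* rgb, int width, int height) {
        struct ggml_tensor* hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
        float* data = (float*)hint->data;
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int c = 0; c < 3; c++) {
                    // stbi_load interleaves pixels (R, G, B, R, G, B, ...);
                    // ggml stores each channel as a contiguous plane.
                    data[c * width * height + y * width + x] =
                        rgb[(y * width + x) * 3 + c] / 255.0f;
                }
            }
        }
        return hint;
    }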
From 7d973f85b607a68095bb404479d0d58ce226ff9d Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Sun, 31 Dec 2023 17:42:57 -0500
Subject: [PATCH 07/29] control strength cli param

---
 README.md             |  1 +
 examples/cli/main.cpp | 11 ++++++++++-
 stable-diffusion.cpp  | 39 ++++++++++++++++++++++++++-------------
 stable-diffusion.h    |  3 ++-
 4 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index 2de447d54..159f4277d 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd)
 - Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN)
 - VAE tiling processing for reduce memory usage
+- Control Net support with SD 1.5
 - Sampling method
     - `Euler A`
     - `Euler`
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 463329011..09babe4cf 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -79,6 +79,7 @@ struct SDParams {
     Schedule schedule = DEFAULT;
     int sample_steps  = 20;
     float strength    = 0.75f;
+    float control_strength = 0.9f;
     RNGType rng_type  = CUDA_RNG;
     int64_t seed      = 42;
     bool verbose      = false;
@@ -135,6 +136,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("  -n, --negative-prompt PROMPT       the negative prompt (default: \"\")\n");
     printf("  --cfg-scale SCALE                  unconditional guidance scale: (default: 7.0)\n");
     printf("  --strength STRENGTH                strength for noising/unnoising (default: 0.75)\n");
     printf("                                     1.0 corresponds to full destruction of information in init image\n");
+    printf("  --control-strength STRENGTH        strength to apply Control Net (default: 0.9)\n");
     printf("  -H, --height H                     image height, in pixel space (default: 512)\n");
     printf("  -W, --width W                      image width, in pixel space (default: 512)\n");
@@ -284,6 +286,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 break;
             }
             params.strength = std::stof(argv[i]);
+        } else if (arg == "--control-strength") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.control_strength = std::stof(argv[i]);
         } else if (arg == "-H" || arg == "--height") {
             if (++i >= argc) {
                 invalid_arg = true;
@@ -523,7 +531,8 @@ int main(int argc, const char* argv[]) {
                              params.sample_steps,
                              params.seed,
                              params.batch_count,
-                             input_image_buffer);
+                             input_image_buffer,
+                             params.control_strength);
     } else {
         results = sd.img2img(input_image_buffer,
                              params.prompt,
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index aae038931..9d86892d6 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -2700,6 +2700,7 @@ struct UNetModel {
                                 struct ggml_tensor* timesteps,
                                 struct ggml_tensor* context,
                                 std::vector<struct ggml_tensor*> control,
+                                struct ggml_tensor* control_strength,
                                 struct ggml_tensor* t_emb = NULL,
                                 struct ggml_tensor* y     = NULL) {
         // x: [N, in_channels, h, w]
@@ -2757,7 +2758,8 @@ struct UNetModel {
         h = middle_block_1.forward(ctx0, h, context);  // [N, 4*model_channels, h/8, w/8]
         h = middle_block_2.forward(ctx0, h, emb);      // [N, 4*model_channels, h/8, w/8]
         if (control.size() > 0) {
-            h = ggml_add(ctx0, h, control[control.size() - 1]);  // middle control
+            auto cs = ggml_scale_inplace(ctx0, control[control.size() - 1], control_strength);
+            h       = ggml_add(ctx0, h, cs);  // middle control
         }

         int control_offset = control.size() - 2;
@@ -2767,7 +2769,8 @@ struct UNetModel {
                 auto h_skip = hs.back();
                 hs.pop_back();
                 if (control.size() > 0) {
-                    h_skip = ggml_add(ctx0, h_skip, control[control_offset]);  // control net condition
+                    auto cs = ggml_scale_inplace(ctx0, control[control_offset], control_strength);
+                    h_skip  = ggml_add(ctx0, h_skip, cs);  // control net condition
                     control_offset--;
                 }
@@ -2801,7 +2804,8 @@ struct UNetModel {
                                     struct ggml_tensor* context,
                                     std::vector<struct ggml_tensor*> control,
                                     struct ggml_tensor* t_emb = NULL,
-                                    struct ggml_tensor* y     = NULL) {
+                                    struct ggml_tensor* y     = NULL,
+                                    float control_net_strength = 1.0) {
         // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
         static size_t buf_size = ggml_tensor_overhead() * UNET_GRAPH_SIZE + ggml_graph_overhead();
         static std::vector<char> buf(buf_size);
@@ -2823,6 +2827,7 @@ struct UNetModel {
         struct ggml_tensor* context_t = NULL;
         struct ggml_tensor* t_emb_t   = NULL;
         struct ggml_tensor* y_t       = NULL;
+        struct ggml_tensor* control_strength = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);

         // it's performing a compute, check if backend isn't cpu
         if (!ggml_backend_is_cpu(backend)) {
@@ -2866,7 +2871,12 @@ struct UNetModel {
             y_t = y;
         }

-        struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, control, t_emb_t, y_t);
+        ggml_allocr_alloc(compute_alloc, control_strength);
+        if (!ggml_allocr_is_measure(compute_alloc)) {
+            ggml_backend_tensor_set(control_strength, &control_net_strength, 0, sizeof(float));
+        }
+
+        struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, control, control_strength, t_emb_t, y_t);

         ggml_build_forward_expand(gf, out);
         ggml_free(ctx0);
@@ -2883,7 +2893,7 @@ struct UNetModel {
         // alignment required by the backend
         compute_alloc = ggml_allocr_new_measure_from_backend(backend);

-        struct ggml_cgraph* gf = build_graph(x, NULL, context, control, t_emb, y);
+        struct ggml_cgraph* gf = build_graph(x, NULL, context, control, t_emb, y, 1.0f);

         // compute the required memory
         compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf);
@@ -2905,11 +2915,12 @@ struct UNetModel {
                  struct ggml_tensor* context,
                  std::vector<struct ggml_tensor*> control,
                  struct ggml_tensor* t_emb = NULL,
-                 struct ggml_tensor* y     = NULL) {
+                 struct ggml_tensor* y     = NULL,
+                 float control_net_strength) {
         ggml_allocr_reset(compute_alloc);

         // compute
-        struct ggml_cgraph* gf = build_graph(x, timesteps, context, control, t_emb, y);
+        struct ggml_cgraph* gf = build_graph(x, timesteps, context, control, t_emb, y, control_net_strength);

         ggml_allocr_alloc_graph(compute_alloc, gf);
@@ -6543,7 +6554,8 @@ class StableDiffusionGGML {
                         ggml_tensor* control_hint,
                         float cfg_scale,
                         SampleMethod method,
-                        const std::vector<float>& sigmas) {
+                        const std::vector<float>& sigmas,
+                        float control_strength) {
         size_t steps = sigmas.size() - 1;
         // x_t = load_tensor_from_file(work_ctx, "./rand0.bin");
         // print_ggml_tensor(x_t);
@@ -6611,7 +6623,7 @@ class StableDiffusionGGML {
                 if (control_hint != NULL) {
                     control_net.compute(n_threads, noised_input, control_hint, c, t_emb);
                 }
-                diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, control_net.controls, t_emb, c_vector);
+                diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, control_net.controls, t_emb, c_vector, control_strength);

                 float* negative_data = NULL;
                 if (has_unconditioned) {
@@ -6620,7 +6632,7 @@ class StableDiffusionGGML {
                         control_net.compute(n_threads, noised_input, control_hint, uc, t_emb);
                     }
-                    diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, control_net.controls, t_emb, uc_vector);
+                    diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, control_net.controls, t_emb, uc_vector, control_strength);
                     negative_data = (float*)out_uncond->data;
                 }
                 float* vec_denoised = (float*)denoised->data;
@@ -7206,7 +7218,8 @@ std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,
                                                int sample_steps,
                                                int64_t seed,
                                                int batch_count,
-                                               const uint8_t* control_cond) {
+                                               const uint8_t* control_cond,
+                                               float control_strength) {
     std::vector<uint8_t*> results;
     // if (width >= 1024 && height >= 1024) {  // 1024 x 1024 images
     //     LOG_WARN("Image too large, try a smaller size.");
@@ -7293,7 +7306,7 @@ std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,

         std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);

-        struct ggml_tensor* x_0 = sd->sample(work_ctx, x_t, NULL, c, c_vector, uc, uc_vector, image_hint, cfg_scale, sample_method, sigmas);
+        struct ggml_tensor* x_0 = sd->sample(work_ctx, x_t, NULL, c, c_vector, uc, uc_vector, image_hint, cfg_scale, sample_method, sigmas, control_strength);
         // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
         // print_ggml_tensor(x_0);
         int64_t sampling_end = ggml_time_ms();
@@ -7440,7 +7453,7 @@ std::vector<uint8_t*> StableDiffusion::img2img(const uint8_t* init_img_data,
         ggml_tensor_set_f32_randn(noise, sd->rng);

         LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
-        struct ggml_tensor* x_0 = sd->sample(work_ctx, init_latent, noise, c, c_vector, uc, uc_vector, NULL, cfg_scale, sample_method, sigma_sched);
+        struct ggml_tensor* x_0 = sd->sample(work_ctx, init_latent, noise, c, c_vector, uc, uc_vector, NULL, cfg_scale, sample_method, sigma_sched, 1.0f);
         // struct ggml_tensor *x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
         // print_ggml_tensor(x_0);
         int64_t t3 = ggml_time_ms();
diff --git a/stable-diffusion.h b/stable-diffusion.h
index f32f82bce..8a3a17049 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -65,7 +65,8 @@ class StableDiffusion {
         int sample_steps,
         int64_t seed,
         int batch_count,
-        const uint8_t* control_cond);
+        const uint8_t* control_cond,
+        float control_strength);

     std::vector<uint8_t*> img2img(
         const uint8_t* init_img_data,
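A design point worth making explicit before the fix that follows: the strength travels into the graph as a one-element tensor set with ggml_backend_tensor_set, not as a constant folded into the nodes, because begin() measures the graph once and the same topology is then reused for every sampling step and every strength value. Condensed, the pattern is:

    // A run-time scalar as a graph input: allocate a 1-element tensor in the
    // graph, upload the value before each compute, let ggml_scale read it.
    struct ggml_tensor* strength = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
    ggml_allocr_alloc(compute_alloc, strength);
    if (!ggml_allocr_is_measure(compute_alloc)) {
        float value = 0.9f;  // e.g. the --control-strength argument
        ggml_backend_tensor_set(strength, &value, 0, sizeof(float));
    }
    // each control output is scaled before being added to the UNet skip
    h_skip = ggml_add(ctx0, h_skip, ggml_scale_inplace(ctx0, control[i], strength));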
From 81a9c6495aebc502866f48b1fba43c7e7dfa4e80 Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Sun, 31 Dec 2023 17:51:47 -0500
Subject: [PATCH 08/29] fix ci errors

---
 stable-diffusion.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 9d86892d6..9e6efda13 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -2914,9 +2914,9 @@ struct UNetModel {
                  struct ggml_tensor* timesteps,
                  struct ggml_tensor* context,
                  std::vector<struct ggml_tensor*> control,
+                 float control_net_strength,
                  struct ggml_tensor* t_emb = NULL,
-                 struct ggml_tensor* y     = NULL,
-                 float control_net_strength) {
+                 struct ggml_tensor* y     = NULL) {
         ggml_allocr_reset(compute_alloc);

         // compute
@@ -6384,7 +6384,7 @@ class StableDiffusionGGML {
         ggml_set_f32(timesteps, 999);
         set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
         struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
-        diffusion_model.compute(out, n_threads, x_t, NULL, c, std::vector<struct ggml_tensor*>(), t_emb);
+        diffusion_model.compute(out, n_threads, x_t, NULL, c, std::vector<struct ggml_tensor*>(), 1.0f, t_emb);
        diffusion_model.end();

         double result = 0.f;
@@ -6623,7 +6623,7 @@ class StableDiffusionGGML {
                 if (control_hint != NULL) {
                     control_net.compute(n_threads, noised_input, control_hint, c, t_emb);
                 }
-                diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, control_net.controls, t_emb, c_vector, control_strength);
+                diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, control_net.controls, control_strength, t_emb, c_vector);

                 float* negative_data = NULL;
                 if (has_unconditioned) {
@@ -6632,7 +6632,7 @@ class StableDiffusionGGML {
                         control_net.compute(n_threads, noised_input, control_hint, uc, t_emb);
                     }
-                    diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, control_net.controls, t_emb, uc_vector, control_strength);
+                    diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, control_net.controls, control_strength, t_emb, uc_vector);
                     negative_data = (float*)out_uncond->data;
                 }
                 float* vec_denoised = (float*)denoised->data;
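The "ci errors" fixed here are a C++ language rule rather than a logic bug: every parameter that follows a defaulted parameter must itself have a default, so the non-defaulted float control_net_strength added in patch 07 could not legally come after t_emb = NULL and y = NULL. Moving it ahead of the defaulted parameters, as the hunk above does, is the minimal fix:

    // ill-formed: non-default parameter after defaulted ones
    void compute(ggml_tensor* t_emb = NULL, ggml_tensor* y = NULL, float strength);
    // well-formed: defaults only at the end
    void compute(float strength, ggml_tensor* t_emb = NULL, ggml_tensor* y = NULL);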
From 30037eca7cda6c388e829126cb31b55106b57fa8 Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Sun, 31 Dec 2023 19:50:19 -0500
Subject: [PATCH 09/29] cli param keep controlnet in cpu

---
 examples/cli/main.cpp |  6 +++++-
 stable-diffusion.cpp  | 39 ++++++++++++++++++++++++++++++-------
 stable-diffusion.h    |  3 ++-
 3 files changed, 39 insertions(+), 9 deletions(-)

diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 09babe4cf..dd1fb852f 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -84,6 +84,7 @@ struct SDParams {
     int64_t seed = 42;
     bool verbose = false;
     bool vae_tiling = false;
+    bool control_net_cpu = false;
 };

 void print_params(SDParams params) {
@@ -150,6 +151,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("  --clip-skip N                      ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
     printf("                                     <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
     printf("  --vae-tiling                       process vae in tiles to reduce memory usage\n");
+    printf("  --control-net-cpu                  keep controlnet in cpu (for low vram)\n");
     printf("  -v, --verbose                      print extra info\n");
 }

@@ -318,6 +320,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
             params.clip_skip = std::stoi(argv[i]);
         } else if (arg == "--vae-tiling") {
             params.vae_tiling = true;
+        } else if (arg == "--control-net-cpu") {
+            params.control_net_cpu = true;
         } else if (arg == "-b" || arg == "--batch-count") {
             if (++i >= argc) {
                 invalid_arg = true;
@@ -508,7 +512,7 @@ int main(int argc, const char* argv[]) {
     StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type);

-    if (!sd.load_from_file(params.model_path, params.vae_path, params.controlnet_path, params.wtype, params.schedule, params.clip_skip)) {
+    if (!sd.load_from_file(params.model_path, params.vae_path, params.controlnet_path, params.wtype, params.schedule, params.clip_skip, params.control_net_cpu)) {
         return 1;
     }

diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 9e6efda13..3e275e7d2 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -2828,6 +2828,7 @@ struct UNetModel {
         struct ggml_tensor* t_emb_t   = NULL;
         struct ggml_tensor* y_t       = NULL;
         struct ggml_tensor* control_strength = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        std::vector<struct ggml_tensor*> control_t;

         // it's performing a compute, check if backend isn't cpu
         if (!ggml_backend_is_cpu(backend)) {
@@ -2872,12 +2873,27 @@ struct UNetModel {
             y_t = y;
         }

+        // offload all controls tensors to gpu
+        if (control.size() > 0 && !ggml_backend_is_cpu(backend) && control[0]->backend != GGML_BACKEND_GPU) {
+            for (int i = 0; i < control.size(); i++) {
+                ggml_tensor* cntl_t = ggml_dup_tensor(ctx0, control[i]);
+                control_t.push_back(cntl_t);
+                ggml_allocr_alloc(compute_alloc, cntl_t);
+                if (!ggml_allocr_is_measure(compute_alloc)) {
+                    ggml_backend_tensor_copy(control[i], control_t[i]);
+                    ggml_backend_synchronize(backend);
+                }
+            }
+        } else {
+            control_t = control;
+        }
+
         ggml_allocr_alloc(compute_alloc, control_strength);
         if (!ggml_allocr_is_measure(compute_alloc)) {
             ggml_backend_tensor_set(control_strength, &control_net_strength, 0, sizeof(float));
         }

-        struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, control, control_strength, t_emb_t, y_t);
+        struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, control_t, control_strength, t_emb_t, y_t);

         ggml_build_forward_expand(gf, out);
         ggml_free(ctx0);
@@ -5345,7 +5361,7 @@ struct ControlNet {
         ggml_allocr_free(alloc);
     }

-    bool load_from_file(const std::string& file_path, ggml_backend_t backend, ggml_type wtype__) {
+    bool load_from_file(const std::string& file_path, ggml_backend_t backend_, ggml_type wtype_) {
         LOG_INFO("loading control net from '%s'", file_path.c_str());

         std::map<std::string, struct ggml_tensor*> control_tensors;
@@ -5356,7 +5372,7 @@ struct ControlNet {
             return false;
         }

-        if (!init(backend, wtype__)) {
+        if (!init(backend_, wtype_)) {
             return false;
         }

@@ -6094,7 +6110,8 @@ class StableDiffusionGGML {
                         const std::string control_net_path,
                         ggml_type wtype,
                         Schedule schedule,
-                        int clip_skip) {
+                        int clip_skip,
+                        bool control_net_cpu) {
 #ifdef SD_USE_CUBLAS
         LOG_DEBUG("Using CUDA backend");
         backend = ggml_backend_cuda_init(0);
@@ -6364,7 +6381,14 @@ class StableDiffusionGGML {
         }

         if (control_net_path.size() > 0) {
-            return control_net.load_from_file(control_net_path, backend, GGML_TYPE_F16 /* just f16 controlnet models */);
+            ggml_backend_t cn_backend = NULL;
+            if (control_net_cpu && !ggml_backend_is_cpu(backend)) {
+                LOG_DEBUG("ControlNet: Using CPU backend");
+                cn_backend = ggml_backend_cpu_init();
+            } else {
+                cn_backend = backend;
+            }
+            return control_net.load_from_file(control_net_path, cn_backend, GGML_TYPE_F16 /* just f16 controlnet models */);
         }
         return true;
     }
@@ -7229,8 +7246,9 @@ bool StableDiffusion::load_from_file(const std::string& model_path,
                                      const std::string control_net_path,
                                      ggml_type wtype,
                                      Schedule s,
-                                     int clip_skip) {
-    return sd->load_from_file(model_path, vae_path, control_net_path, wtype, s, clip_skip);
+                                     int clip_skip,
+                                     bool control_net_cpu) {
+    return sd->load_from_file(model_path, vae_path, control_net_path, wtype, s, clip_skip, control_net_cpu);
 }

 std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 8a3a17049..0f0eea789 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -53,7 +53,8 @@ class StableDiffusion {
                         const std::string control_net_path,
                         ggml_type wtype,
                         Schedule d = DEFAULT,
-                        int clip_skip = -1);
+                        int clip_skip = -1,
+                        bool control_net_cpu = false);

     std::vector<uint8_t*> txt2img(
         std::string prompt,

From 4c0fbe5f74be94f3c8d03fc26377e577512c3bca Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Mon, 1 Jan 2024 10:28:11 -0500
Subject: [PATCH 10/29] debug ckpt

---
 model.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/model.cpp b/model.cpp
index 01b89030d..ed4ae4045 100644
--- a/model.cpp
+++ b/model.cpp
@@ -968,6 +968,7 @@ struct PickleTensorReader {
     }

     void read_string(const std::string& str, struct zip_t* zip, std::string dir) {
+        printf("%s\n", str.c_str());
         if (str == "storage") {
             read_global_type = true;
         } else if (str != "state_dict") {
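With patches 06-09 applied, the whole ControlNet configuration flows through the public API. A hypothetical caller, mirroring the argument order used in examples/cli/main.cpp above (the paths, prompt, and control_image_rgb buffer are placeholders, and the middle txt2img arguments are assumed to follow the upstream width/height/sampler order):

    StableDiffusion sd(/*n_threads*/ 4, /*vae_decode_only*/ true,
                       /*taesd_path*/ "", /*esrgan_path*/ "",
                       true, /*vae_tiling*/ false, /*lora_model_dir*/ "", CUDA_RNG);

    if (!sd.load_from_file("model.safetensors", /*vae_path*/ "",
                           "control_openpose-fp16.safetensors", GGML_TYPE_COUNT,
                           DEFAULT, /*clip_skip*/ -1, /*control_net_cpu*/ true)) {
        return 1;
    }

    // control_image_rgb: width * height * 3 bytes, e.g. from stbi_load()
    std::vector<uint8_t*> results =
        sd.txt2img("a portrait photo", /*negative_prompt*/ "", /*cfg_scale*/ 7.0f,
                   /*width*/ 512, /*height*/ 512, EULER_A, /*sample_steps*/ 20,
                   /*seed*/ 42, /*batch_count*/ 1,
                   control_image_rgb, /*control_strength*/ 0.9f);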
From 249875db89e3bcc73a9dd652c5d5939b6135c574 Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Mon, 1 Jan 2024 10:37:32 -0500
Subject: [PATCH 11/29] admit pth models

---
 model.cpp | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/model.cpp b/model.cpp
index ed4ae4045..92f0064c6 100644
--- a/model.cpp
+++ b/model.cpp
@@ -403,6 +403,11 @@ std::string convert_tensor_name(const std::string& name) {
         } else {
             new_name = name;
         }
+    } else if (starts_with(name, "control_model")) {
+        size_t pos = name.find('.');
+        if (pos != std::string::npos) {
+            new_name = name.substr(pos + 1);
+        }
     } else {
         new_name = name;
     }

From c7377003d4a432f72d51c968ca6b585cb6e4fde7 Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Mon, 1 Jan 2024 10:40:22 -0500
Subject: [PATCH 12/29] debug name

---
 model.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/model.cpp b/model.cpp
index 92f0064c6..d9444e2ed 100644
--- a/model.cpp
+++ b/model.cpp
@@ -403,10 +403,11 @@ std::string convert_tensor_name(const std::string& name) {
         } else {
             new_name = name;
         }
-    } else if (starts_with(name, "control_model")) {
+    } else if (starts_with(name, "control_model.")) {
         size_t pos = name.find('.');
         if (pos != std::string::npos) {
             new_name = name.substr(pos + 1);
+            printf("controlnet tensor: %s\n", new_name.c_str());
         }
     } else {
         new_name = name;

From fdf63e005671a91abe2c3d9eca7e4d845d22f6cd Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Mon, 1 Jan 2024 10:45:42 -0500
Subject: [PATCH 13/29] remove ignore controlnet tensors

---
 model.cpp | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/model.cpp b/model.cpp
index d9444e2ed..d762d6632 100644
--- a/model.cpp
+++ b/model.cpp
@@ -86,7 +86,6 @@ const char* unused_tensors[] = {
     "model_ema.decay",
     "model_ema.num_updates",
     "model_ema.diffusion_model",
-    "control_model",
     "embedding_manager",
     "denoiser.sigmas",
 };
@@ -373,6 +372,11 @@ std::string convert_tensor_name(const std::string& name) {
         new_name = convert_open_clip_to_hf_clip(name);
     } else if (starts_with(name, "first_stage_model.decoder")) {
         new_name = convert_vae_decoder_name(name);
+    } else if (starts_with(name, "control_model.")) {  // for controlnet pth models
+        size_t pos = name.find('.');
+        if (pos != std::string::npos) {
+            new_name = name.substr(pos + 1);
+        }
     } else if (starts_with(name, "lora_")) {  // for lora
         size_t pos = name.find('.');
         if (pos != std::string::npos) {
@@ -403,12 +407,6 @@ std::string convert_tensor_name(const std::string& name) {
         } else {
             new_name = name;
         }
-    } else if (starts_with(name, "control_model.")) {
-        size_t pos = name.find('.');
-        if (pos != std::string::npos) {
-            new_name = name.substr(pos + 1);
-            printf("controlnet tensor: %s\n", new_name.c_str());
-        }
     } else {
         new_name = name;
     }
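Net effect of patches 11-13 on checkpoint loading: ControlNet tensors are no longer silently dropped as "unused", and names coming from .pth checkpoints, which arrive with a control_model. prefix, are folded onto the prefix-free names that map_by_name() registers. For example, with names taken from the hint-block and time-embed mappings in this series:

    control_model.input_hint_block.0.weight  ->  input_hint_block.0.weight
    control_model.time_embed.0.weight        ->  time_embed.0.weight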
From 5da72348ef12f2ebc0ba2dc6a82ea9d9868a26fe Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Tue, 2 Jan 2024 13:36:08 -0500
Subject: [PATCH 14/29] integrate refactor changes

---
 control.hpp           | 642 ++++++++++++++++++++++++++++++++++++++++++
 examples/cli/main.cpp |  13 +-
 ggml_extend.hpp       |   2 +-
 stable-diffusion.cpp  |  46 +--
 stable-diffusion.h    |   8 +-
 unet.hpp              |  52 +++-
 6 files changed, 730 insertions(+), 33 deletions(-)
 create mode 100644 control.hpp

diff --git a/control.hpp b/control.hpp
new file mode 100644
index 000000000..c15de901d
--- /dev/null
+++ b/control.hpp
@@ -0,0 +1,642 @@
+#ifndef __CONTROL_HPP__
+#define __CONTROL_HPP__
+
+#include "ggml_extend.hpp"
+#include "unet.hpp"
+#include "model.h"
+
+/*
+    =================================== ControlNet ===================================
+    Reference: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/cldm/cldm.py
+*/
+
+struct CNHintBlock {
+    int hint_channels    = 3;
+    int model_channels   = 320;  // SD 1.5
+    int feat_channels[4] = {16, 32, 96, 256};
+    int num_blocks       = 3;
+    ggml_tensor* conv_first_w;  // [feat_channels[0], hint_channels, 3, 3]
+    ggml_tensor* conv_first_b;  // [feat_channels[0]]
+
+    struct hint_block {
+        ggml_tensor* conv_0_w;  // [feat_channels[idx], feat_channels[idx], 3, 3]
+        ggml_tensor* conv_0_b;  // [feat_channels[idx]]
+
+        ggml_tensor* conv_1_w;  // [feat_channels[idx + 1], feat_channels[idx], 3, 3]
+        ggml_tensor* conv_1_b;  // [feat_channels[idx + 1]]
+    };
+
+    hint_block blocks[3];
+    ggml_tensor* conv_final_w;  // [model_channels, feat_channels[3], 3, 3]
+    ggml_tensor* conv_final_b;  // [model_channels]
+
+    size_t calculate_mem_size() {
+        size_t mem_size = feat_channels[0] * hint_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_first_w
+        mem_size += feat_channels[0] * ggml_type_size(GGML_TYPE_F32);                                // conv_first_b
+        for (int i = 0; i < num_blocks; i++) {
+            mem_size += feat_channels[i] * feat_channels[i] * 3 * 3 * ggml_type_size(GGML_TYPE_F16);      // conv_0_w
+            mem_size += feat_channels[i] * ggml_type_size(GGML_TYPE_F32);                                 // conv_0_b
+            mem_size += feat_channels[i + 1] * feat_channels[i] * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_1_w
+            mem_size += feat_channels[i + 1] * ggml_type_size(GGML_TYPE_F32);                             // conv_1_b
+        }
+        mem_size += model_channels * feat_channels[3] * 3 * 3 * ggml_type_size(GGML_TYPE_F16);  // conv_final_w
+        mem_size += model_channels * ggml_type_size(GGML_TYPE_F32);                             // conv_final_b
+        return static_cast<size_t>(mem_size);
+    }
+
+    void init_params(struct ggml_context* ctx) {
+        conv_first_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, hint_channels, feat_channels[0]);
+        conv_first_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, feat_channels[0]);
+
+        for (int i = 0; i < num_blocks; i++) {
+            blocks[i].conv_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, feat_channels[i], feat_channels[i]);
+            blocks[i].conv_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, feat_channels[i]);
+            blocks[i].conv_1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, feat_channels[i], feat_channels[i + 1]);
+            blocks[i].conv_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, feat_channels[i + 1]);
+        }
+
+        conv_final_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, feat_channels[3], model_channels);
+        conv_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model_channels);
+    }
+
+    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
+        tensors[prefix + "input_hint_block.0.weight"] = conv_first_w;
+        tensors[prefix + "input_hint_block.0.bias"]   = conv_first_b;
+        int index = 2;
+        for (int i = 0; i < num_blocks; i++) {
+            tensors[prefix + "input_hint_block." + std::to_string(index) + ".weight"] = blocks[i].conv_0_w;
+            tensors[prefix + "input_hint_block." + std::to_string(index) + ".bias"]   = blocks[i].conv_0_b;
+            index += 2;
+            tensors[prefix + "input_hint_block." + std::to_string(index) + ".weight"] = blocks[i].conv_1_w;
+            tensors[prefix + "input_hint_block." + std::to_string(index) + ".bias"]   = blocks[i].conv_1_b;
+            index += 2;
+        }
+        tensors[prefix + "input_hint_block.14.weight"] = conv_final_w;
+        tensors[prefix + "input_hint_block.14.bias"]   = conv_final_b;
+    }
+
+    struct ggml_tensor* forward(ggml_context* ctx, struct ggml_tensor* x) {
+        auto h = ggml_nn_conv_2d(ctx, x, conv_first_w, conv_first_b, 1, 1, 1, 1);
+        h      = ggml_silu_inplace(ctx, h);
+
+        auto body_h = h;
+        for (int i = 0; i < num_blocks; i++) {
+            // operations.conv_nd(dims, 16, 16, 3, padding=1)
+            body_h = ggml_nn_conv_2d(ctx, body_h, blocks[i].conv_0_w, blocks[i].conv_0_b, 1, 1, 1, 1);
+            body_h = ggml_silu_inplace(ctx, body_h);
+            // operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2)
+            body_h = ggml_nn_conv_2d(ctx, body_h, blocks[i].conv_1_w, blocks[i].conv_1_b, 2, 2, 1, 1);
+            body_h = ggml_silu_inplace(ctx, body_h);
+        }
+
+        h = ggml_nn_conv_2d(ctx, body_h, conv_final_w, conv_final_b, 1, 1, 1, 1);
+        h = ggml_silu_inplace(ctx, h);
+        return h;
+    }
+};
+
+struct CNZeroConv {
+    int channels;
+    ggml_tensor* conv_w;  // [channels, channels, 1, 1]
+    ggml_tensor* conv_b;  // [channels]
+
+    void init_params(struct ggml_context* ctx) {
+        conv_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, channels, channels);
+        conv_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels);
+    }
+};
+
+struct ControlNet : public GGMLModule {
+    int in_channels    = 4;
+    int model_channels = 320;
+    int out_channels   = 4;
+    int num_res_blocks = 2;
+    std::vector<int> attention_resolutions = {4, 2, 1};
+    std::vector<int> channel_mult          = {1, 2, 4, 4};
+    std::vector<int> transformer_depth     = {1, 1, 1, 1};
+    int time_embed_dim    = 1280;  // model_channels*4
+    int num_heads         = 8;
+    int num_head_channels = -1;  // channels // num_heads
+    int context_dim       = 768;
+    int middle_out_channel;
+    CNHintBlock input_hint_block;
+    CNZeroConv zero_convs[12];
+    int num_zero_convs = 1;
+
+    // network params
+    struct ggml_tensor* time_embed_0_w;  // [time_embed_dim, model_channels]
+    struct ggml_tensor* time_embed_0_b;  // [time_embed_dim, ]
+    // time_embed_1 is nn.SILU()
+    struct ggml_tensor* time_embed_2_w;  // [time_embed_dim, time_embed_dim]
+    struct ggml_tensor* time_embed_2_b;  // [time_embed_dim, ]
+
+    struct ggml_tensor* input_block_0_w;  // [model_channels, in_channels, 3, 3]
+    struct ggml_tensor* input_block_0_b;  // [model_channels, ]
+
+    // input_blocks
+    ResBlock input_res_blocks[4][2];
+    SpatialTransformer input_transformers[3][2];
+    DownSample input_down_samples[3];
+
+    // middle_block
+    ResBlock middle_block_0;
+    SpatialTransformer middle_block_1;
+    ResBlock middle_block_2;
+
+    struct ggml_tensor* middle_block_out_w;  // [middle_out_channel, middle_out_channel, 1, 1]
+    struct ggml_tensor* middle_block_out_b;  // [middle_out_channel, ]
+    ggml_backend_buffer_t control_buffer = NULL;  // keep control output tensors in backend memory
+    ggml_context* control_ctx            = NULL;
+    std::vector<struct ggml_tensor*> controls;  // (12 input block outputs, 1 middle block output) SD 1.5
+
+    ControlNet() {
+        // input_blocks
+        std::vector<int> input_block_chans;
+        input_block_chans.push_back(model_channels);
+        int ch                 = model_channels;
+        zero_convs[0].channels = model_channels;
+        int ds                 = 1;
+
+        int len_mults = channel_mult.size();
+        for (int i = 0; i < len_mults; i++) {
+            int mult = channel_mult[i];
+            for (int j = 0; j < num_res_blocks; j++) {
+                input_res_blocks[i][j].channels     = ch;
+                input_res_blocks[i][j].emb_channels = time_embed_dim;
+                input_res_blocks[i][j].out_channels = mult * model_channels;
+
+                ch = mult * model_channels;
+                if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
+                    int n_head = num_heads;
+                    int d_head = ch / num_heads;
+                    if (num_head_channels != -1) {
+                        d_head = num_head_channels;
+                        n_head = ch / d_head;
+                    }
+                    input_transformers[i][j]             = SpatialTransformer(transformer_depth[i]);
+                    input_transformers[i][j].in_channels = ch;
+                    input_transformers[i][j].n_head      = n_head;
+                    input_transformers[i][j].d_head      = d_head;
+                    input_transformers[i][j].context_dim = context_dim;
+                }
+                input_block_chans.push_back(ch);
+
+                zero_convs[num_zero_convs].channels = ch;
+                num_zero_convs++;
+            }
+            if (i != len_mults - 1) {
+                input_down_samples[i].channels     = ch;
+                input_down_samples[i].out_channels = ch;
+                input_block_chans.push_back(ch);
+
+                zero_convs[num_zero_convs].channels = ch;
+                num_zero_convs++;
+                ds *= 2;
+            }
+        }
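+        // Why twelve: one zero conv for the stem conv (input_blocks.0), one per
+        // res block (4 levels x 2 = 8) and one per downsample (3); the middle
+        // block output is projected separately through middle_block_out.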
mem_size += zero_convs[i].channels * ggml_type_sizef(GGML_TYPE_F32); + } + + // middle_block + mem_size += middle_block_0.calculate_mem_size(wtype); + mem_size += middle_block_1.calculate_mem_size(wtype); + mem_size += middle_block_2.calculate_mem_size(wtype); + + mem_size += middle_out_channel * middle_out_channel * ggml_type_sizef(GGML_TYPE_F16); // middle_block_out_w + mem_size += middle_out_channel * ggml_type_sizef(GGML_TYPE_F32); // middle_block_out_b + + return static_cast<size_t>(mem_size); + } + + size_t get_num_tensors() { + // in + size_t num_tensors = 6; + + num_tensors += num_zero_convs * 2; + + // input blocks + int ds = 1; + int len_mults = channel_mult.size(); + for (int i = 0; i < len_mults; i++) { + for (int j = 0; j < num_res_blocks; j++) { + num_tensors += 12; + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + num_tensors += input_transformers[i][j].get_num_tensors(); + } + } + if (i != len_mults - 1) { + ds *= 2; + num_tensors += 2; + } + } + + // middle blocks + num_tensors += 13 * 2; + num_tensors += middle_block_1.get_num_tensors(); + return num_tensors; + } + + void init_params() { + ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer); + + input_hint_block.init_params(params_ctx); + + time_embed_0_w = ggml_new_tensor_2d(params_ctx, wtype, model_channels, time_embed_dim); + time_embed_0_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, time_embed_dim); + time_embed_2_w = ggml_new_tensor_2d(params_ctx, wtype, time_embed_dim, time_embed_dim); + time_embed_2_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, time_embed_dim); + + // input_blocks + input_block_0_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, in_channels, model_channels); + input_block_0_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, model_channels); + + int ds = 1; + int len_mults = channel_mult.size(); + for (int i = 0; i < len_mults; i++) { + for (int j = 0; j < num_res_blocks; j++) { + input_res_blocks[i][j].init_params(params_ctx, wtype); + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + input_transformers[i][j].init_params(params_ctx, alloc, wtype); + } + } + if (i != len_mults - 1) { + input_down_samples[i].init_params(params_ctx, wtype); + ds *= 2; + } + } + + for (int i = 0; i < num_zero_convs; i++) { + zero_convs[i].init_params(params_ctx); + } + + // middle_blocks + middle_block_0.init_params(params_ctx, wtype); + middle_block_1.init_params(params_ctx, alloc, wtype); + middle_block_2.init_params(params_ctx, wtype); + + // middle_block_out + middle_block_out_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 1, 1, middle_out_channel, middle_out_channel); + middle_block_out_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, middle_out_channel); + + // alloc all tensors linked to this context + for (struct ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != NULL; t = ggml_get_next_tensor(params_ctx, t)) { + if (t->data == NULL) { + ggml_allocr_alloc(alloc, t); + } + } + + ggml_allocr_free(alloc); + }
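The loader that follows is normally driven from StableDiffusionGGML::load_from_file; a minimal stand-alone sketch of the same call, with an illustrative file name and backend choice (both are assumptions, not part of this patch):

    ControlNet control_net;
    ggml_backend_t backend = ggml_backend_cpu_init();  // or a CUDA backend under SD_USE_CUBLAS
    if (!control_net.load_from_file("control_sd15_canny.safetensors", backend, GGML_TYPE_F16)) {
        LOG_ERROR("loading control net failed");
    }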
+ bool load_from_file(const std::string& file_path, ggml_backend_t backend_, ggml_type wtype_) { + LOG_INFO("loading control net from '%s'", file_path.c_str()); + + std::map<std::string, struct ggml_tensor*> control_tensors; + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path)) { + LOG_ERROR("init control net model loader from file failed: '%s'", file_path.c_str()); + return false; + } + + if (!alloc_params_buffer(backend_, wtype_)) { + return false; + } + + // prepare memory for the weights + { + init_params(); + map_by_name(control_tensors, ""); + } + + std::set<std::string> tensor_names_in_file; + + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { + const std::string& name = tensor_storage.name; + tensor_names_in_file.insert(name); + + struct ggml_tensor* real; + if (control_tensors.find(name) != control_tensors.end()) { + real = control_tensors[name]; + } else { + LOG_ERROR("unknown tensor '%s' in model file", name.data()); + return true; + } + + if ( + real->ne[0] != tensor_storage.ne[0] || + real->ne[1] != tensor_storage.ne[1] || + real->ne[2] != tensor_storage.ne[2] || + real->ne[3] != tensor_storage.ne[3]) { + LOG_ERROR( + "tensor '%s' has wrong shape in model file: " + "got [%d, %d, %d, %d], expected [%d, %d, %d, %d]", + name.c_str(), + (int)tensor_storage.ne[0], (int)tensor_storage.ne[1], (int)tensor_storage.ne[2], (int)tensor_storage.ne[3], + (int)real->ne[0], (int)real->ne[1], (int)real->ne[2], (int)real->ne[3]); + return false; + } + + *dst_tensor = real; + + return true; + }; + + bool success = model_loader.load_tensors(on_new_tensor_cb, backend); + + bool some_tensor_not_init = false; + + for (auto pair : control_tensors) { + if (tensor_names_in_file.find(pair.first) == tensor_names_in_file.end()) { + LOG_ERROR("tensor '%s' not in model file", pair.first.c_str()); + some_tensor_not_init = true; + } + } + + if (some_tensor_not_init) { + return false; + } + + LOG_INFO("control net model loaded"); + return success; + } + + void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) { + input_hint_block.map_by_name(tensors, ""); + tensors[prefix + "time_embed.0.weight"] = time_embed_0_w; + tensors[prefix + "time_embed.0.bias"] = time_embed_0_b; + tensors[prefix + "time_embed.2.weight"] = time_embed_2_w; + tensors[prefix + "time_embed.2.bias"] = time_embed_2_b; + + // input_blocks + tensors[prefix + "input_blocks.0.0.weight"] = input_block_0_w; + tensors[prefix + "input_blocks.0.0.bias"] = input_block_0_b; + + int len_mults = channel_mult.size(); + int input_block_idx = 0; + int ds = 1; + for (int i = 0; i < len_mults; i++) { + for (int j = 0; j < num_res_blocks; j++) { + input_block_idx += 1; + input_res_blocks[i][j].map_by_name(tensors, prefix + "input_blocks." + std::to_string(input_block_idx) + ".0."); + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + input_transformers[i][j].map_by_name(tensors, prefix + "input_blocks." + std::to_string(input_block_idx) + ".1."); + } + } + if (i != len_mults - 1) { + input_block_idx += 1; + input_down_samples[i].map_by_name(tensors, prefix + "input_blocks."
+ std::to_string(input_block_idx) + ".0."); + ds *= 2; + } + } + + for (int i = 0; i < num_zero_convs; i++) { + tensors[prefix + "zero_convs."+ std::to_string(i) + ".0.weight"] = zero_convs[i].conv_w; + tensors[prefix + "zero_convs."+ std::to_string(i) + ".0.bias"] = zero_convs[i].conv_b; + } + + // middle_blocks + middle_block_0.map_by_name(tensors, prefix + "middle_block.0."); + middle_block_1.map_by_name(tensors, prefix + "middle_block.1."); + middle_block_2.map_by_name(tensors, prefix + "middle_block.2."); + + tensors[prefix + "middle_block_out.0.weight"] = middle_block_out_w; + tensors[prefix + "middle_block_out.0.bias"] = middle_block_out_b; + } + + void forward(struct ggml_cgraph* gf, + struct ggml_context* ctx0, + struct ggml_tensor* x, + struct ggml_tensor* hint, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL) { + // x: [N, in_channels, h, w] + // timesteps: [N, ] + // t_emb: [N, model_channels] + // context: [N, max_position, hidden_size]([N, 77, 768]) + if (t_emb == NULL && timesteps != NULL) { + t_emb = new_timestep_embedding(ctx0, compute_allocr, timesteps, model_channels); // [N, model_channels] + } + + // time_embed = nn.Sequential + auto emb = ggml_nn_linear(ctx0, t_emb, time_embed_0_w, time_embed_0_b); + emb = ggml_silu_inplace(ctx0, emb); + emb = ggml_nn_linear(ctx0, emb, time_embed_2_w, time_embed_2_b); // [N, time_embed_dim] + + auto guided_hint = input_hint_block.forward(ctx0, hint); + + // input_blocks + int zero_conv_offset = 0; + + // input block 0 + struct ggml_tensor* h = ggml_nn_conv_2d(ctx0, x, input_block_0_w, input_block_0_b, 1, 1, 1, 1); // [N, model_channels, h, w] + h = ggml_add(ctx0, h, guided_hint); + + auto h_c = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset])); + zero_conv_offset++; + + // input block 1-11 + int len_mults = channel_mult.size(); + int ds = 1; + for (int i = 0; i < len_mults; i++) { + int mult = channel_mult[i]; + for (int j = 0; j < num_res_blocks; j++) { + h = input_res_blocks[i][j].forward(ctx0, h, emb); // [N, mult*model_channels, h, w] + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + h = input_transformers[i][j].forward(ctx0, h, context); // [N, mult*model_channels, h, w] + } + h_c = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset])); + zero_conv_offset++; + } + if (i != len_mults - 1) { + ds *= 2; + h = input_down_samples[i].forward(ctx0, h); // [N, mult*model_channels, h/(2^(i+1)), w/(2^(i+1))] + h_c = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset])); + zero_conv_offset++; + } + } + // [N, 4*model_channels, h/8, w/8] + + // middle_block + h = middle_block_0.forward(ctx0, h, emb); // [N, 4*model_channels, h/8, w/8] + h = middle_block_1.forward(ctx0, h, context); // [N, 4*model_channels, h/8, w/8] + h = middle_block_2.forward(ctx0, h, emb); // [N, 4*model_channels, h/8, w/8] + + h_c = ggml_nn_conv_2d(ctx0, h, middle_block_out_w, middle_block_out_b); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset])); + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* x, + struct ggml_tensor* hint, + struct 
ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL) { + // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead() * UNET_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector<char> buf(buf_size); + + struct ggml_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/buf.data(), + /*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + // LOG_DEBUG("mem_size %u ", params.mem_size); + + struct ggml_context* ctx0 = ggml_init(params); + + struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, UNET_GRAPH_SIZE, false); + + // temporary tensors used to transfer the inputs from cpu to gpu if needed + struct ggml_tensor* x_t = NULL; + struct ggml_tensor* hint_t = NULL; + struct ggml_tensor* timesteps_t = NULL; + struct ggml_tensor* context_t = NULL; + struct ggml_tensor* t_emb_t = NULL; + + // when computing on a non-cpu backend, the inputs must be staged in backend memory + if (!ggml_backend_is_cpu(backend)) { + // pass input tensors to gpu memory + x_t = ggml_dup_tensor(ctx0, x); + context_t = ggml_dup_tensor(ctx0, context); + hint_t = ggml_dup_tensor(ctx0, hint); + ggml_allocr_alloc(compute_allocr, x_t); + if (timesteps != NULL) { + timesteps_t = ggml_dup_tensor(ctx0, timesteps); + ggml_allocr_alloc(compute_allocr, timesteps_t); + } + ggml_allocr_alloc(compute_allocr, context_t); + ggml_allocr_alloc(compute_allocr, hint_t); + if (t_emb != NULL) { + t_emb_t = ggml_dup_tensor(ctx0, t_emb); + ggml_allocr_alloc(compute_allocr, t_emb_t); + } + // pass data to device backend + if (!ggml_allocr_is_measure(compute_allocr)) { + ggml_backend_tensor_set(x_t, x->data, 0, ggml_nbytes(x)); + ggml_backend_tensor_set(context_t, context->data, 0, ggml_nbytes(context)); + ggml_backend_tensor_set(hint_t, hint->data, 0, ggml_nbytes(hint)); + if (timesteps_t != NULL) { + ggml_backend_tensor_set(timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps)); + } + if (t_emb_t != NULL) { + ggml_backend_tensor_set(t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); + } + } + } else { + // on the cpu backend the caller's tensors can be used directly + x_t = x; + timesteps_t = timesteps; + context_t = context; + t_emb_t = t_emb; + hint_t = hint; + } + + forward(gf, ctx0, x_t, hint_t, timesteps_t, context_t, t_emb_t); + + ggml_free(ctx0); + + return gf; + }
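The allocator that follows pre-creates the thirteen control outputs in backend memory before any graph is built. Their resolution halves after every three tensors, mirroring the UNet's downsampling levels; a worked layout for a 512x512 image, i.e. a 64x64 latent (batch size 1, ggml ne order [w, h, c, n]):

    // controls[0..2]:  [64, 64,  320, 1]   latent resolution
    // controls[3]:     [32, 32,  320, 1]   after the first DownSample
    // controls[4..5]:  [32, 32,  640, 1]
    // controls[6]:     [16, 16,  640, 1]
    // controls[7..8]:  [16, 16, 1280, 1]
    // controls[9..11]: [ 8,  8, 1280, 1]
    // controls[12]:    [ 8,  8, 1280, 1]   middle_block_out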
+ void alloc_compute_buffer(struct ggml_tensor* x, + struct ggml_tensor* hint, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL) { + { + struct ggml_init_params params; + params.mem_size = static_cast<size_t>(13 * ggml_tensor_overhead()) + 256; + params.mem_buffer = NULL; + params.no_alloc = true; + control_ctx = ggml_init(params); + size_t control_buffer_size = 0; + int w = x->ne[0], h = x->ne[1], steps = 0; + for(int i = 0; i < (num_zero_convs + 1); i++) { + bool last = i == num_zero_convs; + int c = last ? middle_out_channel : zero_convs[i].channels; + if(!last && steps == 3) { + w /= 2; h /= 2; steps = 0; + } + controls.push_back(ggml_new_tensor_4d(control_ctx, GGML_TYPE_F32, w, h, c, 1)); + control_buffer_size += ggml_nbytes(controls[i]); + steps++; + } + control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, backend); + } + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, hint, NULL, context, t_emb); + }; + GGMLModule::alloc_compute_buffer(get_graph); + } + + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* hint, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL) { + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, hint, NULL, context, t_emb); + }; + GGMLModule::compute(get_graph, n_threads, NULL); + } +}; + +#endif // __CONTROL_HPP__ \ No newline at end of file diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 03b98e2d8..e7311fc5f 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -538,6 +538,7 @@ int main(int argc, const char* argv[]) { sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), params.vae_path.c_str(), params.taesd_path.c_str(), + params.controlnet_path.c_str(), params.lora_model_dir.c_str(), vae_decode_only, params.vae_tiling, @@ -545,7 +546,8 @@ params.n_threads, params.wtype, params.rng_type, - params.schedule); + params.schedule, + params.control_net_cpu); if (sd_ctx == NULL) { printf("new_sd_ctx_t failed\n"); @@ -554,6 +556,7 @@ sd_image_t* results; if (params.mode == TXT2IMG) { + sd_image_t* control_image = NULL; if(params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) { int c = 0; input_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3); @@ -561,6 +564,10 @@ fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str()); return 1; } + control_image = new sd_image_t{(uint32_t)params.width, + (uint32_t)params.height, + 3, + input_image_buffer}; } results = txt2img(sd_ctx, params.prompt.c_str(), @@ -573,8 +580,8 @@ params.sample_steps, params.seed, params.batch_count, - input_image_buffer, - params.control_strength); + control_image, + params.control_strength); } else { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 3bd77556c..480ea77ca 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -544,7 +544,7 @@ struct GGMLModule { bool alloc_params_buffer(ggml_backend_t backend_, ggml_type wtype_ = GGML_TYPE_F32) { backend = backend_; wtype = wtype_; - params_buffer_size = 10 * 1024 * 1024; // 10 MB, for padding + params_buffer_size = 4 * 1024 * 1024; // 4 MB, for padding params_buffer_size += calculate_mem_size(); size_t num_tensors = get_num_tensors();
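From the caller's side, the txt2img change above boils down to two new trailing arguments. A sketch of the extended call (the hint buffer and the elided middle arguments are assumptions; the middle of the signature is unchanged by this patch):

    sd_image_t control_image = {512, 512, 3, hint_pixels};  // hint_pixels: preprocessed (e.g. canny) RGB buffer, assumed
    sd_image_t* results = txt2img(sd_ctx, "a portrait",
                                  /* ...negative prompt, cfg, size, sampler, steps, seed as before... */
                                  1,               // batch_count
                                  &control_image,  // new: control condition, may be NULL
                                  0.9f);           // new: control strength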
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index b45a4a2bb..24fa1eae9 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -8,6 +8,7 @@ #include "clip.hpp" #include "denoiser.hpp" +#include "control.hpp" #include "esrgan.hpp" #include "lora.hpp" #include "tae.hpp" @@ -80,6 +81,8 @@ class StableDiffusionGGML { TinyAutoEncoder tae_first_stage; std::string taesd_path; + ControlNet control_net; + StableDiffusionGGML() = default; StableDiffusionGGML(int n_threads, @@ -110,8 +113,7 @@ const std::string& taesd_path, bool vae_tiling, ggml_type wtype, - Schedule schedule, - int clip_skip, + schedule_t schedule, bool control_net_cpu) { #ifdef SD_USE_CUBLAS LOG_DEBUG("Using CUDA backend"); @@ -338,14 +340,13 @@ struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1); // [N, ] struct ggml_tensor* t_emb = new_timestep_embedding(work_ctx, NULL, timesteps, diffusion_model.model_channels); // [N, model_channels] - diffusion_model.begin(x_t, c, std::vector<struct ggml_tensor*>(), t_emb); - int64_t t0 = ggml_time_ms(); ggml_set_f32(timesteps, 999); set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model.alloc_compute_buffer(x_t, c, t_emb); - diffusion_model.compute(out, n_threads, x_t, NULL, c, t_emb); + std::vector<struct ggml_tensor*> controls; + diffusion_model.alloc_compute_buffer(x_t, c, controls, t_emb); + diffusion_model.compute(out, n_threads, x_t, NULL, c, controls, 1.0f, t_emb); diffusion_model.free_compute_buffer(); double result = 0.f; @@ -528,7 +529,7 @@ ggml_tensor* uc_vector, ggml_tensor* control_hint, float cfg_scale, - SampleMethod method, + sample_method_t method, const std::vector<float>& sigmas, float control_strength) { size_t steps = sigmas.size() - 1; @@ -542,10 +543,10 @@ struct ggml_tensor* t_emb = new_timestep_embedding(work_ctx, NULL, timesteps, diffusion_model.model_channels); // [N, model_channels] if(control_hint != NULL) { - control_net.begin(noised_input, control_hint, c, t_emb); + control_net.alloc_compute_buffer(noised_input, control_hint, c, t_emb); } - diffusion_model.begin(noised_input, c, control_net.controls, t_emb, c_vector); + diffusion_model.alloc_compute_buffer(noised_input, c, control_net.controls, t_emb, c_vector); bool has_unconditioned = cfg_scale != 1.0 && uc != NULL; @@ -1127,6 +1128,7 @@ struct sd_ctx_t { sd_ctx_t* new_sd_ctx(const char* model_path_c_str, const char* vae_path_c_str, const char* taesd_path_c_str, + const char* control_net_path_c_str, const char* lora_model_dir_c_str, bool vae_decode_only, bool vae_tiling, @@ -1134,7 +1136,8 @@ int n_threads, enum sd_type_t wtype, enum rng_type_t rng_type, - enum schedule_t s) { + enum schedule_t s, + bool keep_control_net_cpu) { sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); if (sd_ctx == NULL) { return NULL; @@ -1142,6 +1145,7 @@ std::string model_path(model_path_c_str); std::string vae_path(vae_path_c_str); std::string taesd_path(taesd_path_c_str); + std::string control_net_path(control_net_path_c_str); std::string lora_model_dir(lora_model_dir_c_str); sd_ctx->sd = new StableDiffusionGGML(n_threads, @@ -1155,10 +1159,12 @@ if (!sd_ctx->sd->load_from_file(model_path, vae_path, + control_net_path, taesd_path, vae_tiling, (ggml_type)wtype, - s)) { + s, + keep_control_net_cpu)) { delete sd_ctx->sd; sd_ctx->sd = NULL; free(sd_ctx); @@ -1185,7 +1191,9 @@ enum sample_method_t sample_method, int sample_steps, int64_t seed, - int batch_count) { + int batch_count, + const sd_image_t* control_cond, + float control_strength) { LOG_DEBUG("txt2img %dx%d", width, height); if (sd_ctx == NULL) { return NULL; @@ -1256,13 +1264,7 @@ struct ggml_tensor* image_hint = NULL; if(control_cond != NULL) { image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width,
height, 3, 1); - sd_image_to_tensor(control_cond, image_hint); - } - - struct ggml_tensor* image_hint = NULL; - if(control_cond != NULL) { - image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); - sd_image_to_tensor(control_cond, image_hint); + sd_image_to_tensor(control_cond->data, image_hint); } std::vector<struct ggml_tensor*> final_latents; // collect latents to decode @@ -1281,7 +1283,7 @@ std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps); - struct ggml_tensor* x_0 = sd->sample(work_ctx, x_t, NULL, c, c_vector, uc, uc_vector, image_hint, cfg_scale, sample_method, sigmas, control_strength); + struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, x_t, NULL, c, c_vector, uc, uc_vector, image_hint, cfg_scale, sample_method, sigmas, control_strength); // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t sampling_end = ggml_time_ms(); @@ -1434,8 +1436,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng); LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); - struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, init_latent, noise, c, c_vector, uc, uc_vector, - cfg_scale, sample_method, sigma_sched); + struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, init_latent, noise, c, c_vector, uc, + uc_vector, NULL, cfg_scale, sample_method, sigma_sched, 1.0f); // struct ggml_tensor *x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t t3 = ggml_time_ms(); diff --git a/stable-diffusion.h b/stable-diffusion.h index c25791e8f..e7f91e973 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -104,6 +104,7 @@ typedef struct sd_ctx_t sd_ctx_t; SD_API sd_ctx_t* new_sd_ctx(const char* model_path, const char* vae_path, const char* taesd_path, + const char* control_net_path_c_str, const char* lora_model_dir, bool vae_decode_only, bool vae_tiling, @@ -111,7 +112,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path, int n_threads, enum sd_type_t wtype, enum rng_type_t rng_type, - enum schedule_t s); + enum schedule_t s, + bool keep_control_net_cpu); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); @@ -125,7 +127,9 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, enum sample_method_t sample_method, int sample_steps, int64_t seed, - int batch_count); + int batch_count, + const sd_image_t* control_cond, + float control_strength); SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_t init_image, diff --git a/unet.hpp b/unet.hpp index 4210c33b7..9d9c20395 100644 --- a/unet.hpp +++ b/unet.hpp @@ -900,6 +900,8 @@ struct UNetModel : public GGMLModule { struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + std::vector<struct ggml_tensor*> control, + struct ggml_tensor* control_strength, struct ggml_tensor* t_emb = NULL, struct ggml_tensor* y = NULL) { // x: [N, in_channels, h, w] @@ -957,12 +959,24 @@ h = middle_block_1.forward(ctx0, h, context); // [N, 4*model_channels, h/8, w/8] h = middle_block_2.forward(ctx0, h, emb); // [N, 4*model_channels, h/8, w/8] + if(control.size() > 0) { + auto cs = ggml_scale_inplace(ctx0, control[control.size() - 1], control_strength); + h = ggml_add(ctx0, h, cs); // middle control + }
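    // pairing note: control[12] (the middle_block_out projection) was just added
    // to h above; control_offset then counts down from 11 so that each skip
    // popped off `hs` in the output blocks below receives the zero-conv output
    // produced at the same resolution and channel count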
+ + int control_offset = control.size() - 2; // output_blocks for (int i = (int)len_mults - 1; i >= 0; i--) { for (int j = 0; j < num_res_blocks + 1; j++) { auto h_skip = hs.back(); hs.pop_back(); + if(control.size() > 0) { + auto cs = ggml_scale_inplace(ctx0, control[control_offset], control_strength); + h_skip = ggml_add(ctx0, h_skip, cs); // control net condition + control_offset--; + } + h = ggml_concat(ctx0, h, h_skip); h = output_res_blocks[i][j].forward(ctx0, h, emb); @@ -991,8 +1005,10 @@ struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + std::vector<struct ggml_tensor*> control, struct ggml_tensor* t_emb = NULL, - struct ggml_tensor* y = NULL) { + struct ggml_tensor* y = NULL, + float control_net_strength = 1.0) { // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data static size_t buf_size = ggml_tensor_overhead() * UNET_GRAPH_SIZE + ggml_graph_overhead(); static std::vector<char> buf(buf_size); @@ -1014,6 +1030,8 @@ struct ggml_tensor* context_t = NULL; struct ggml_tensor* t_emb_t = NULL; struct ggml_tensor* y_t = NULL; + struct ggml_tensor* control_strength = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + std::vector<struct ggml_tensor*> control_t; // it's performing a compute, check if backend isn't cpu if (!ggml_backend_is_cpu(backend)) { @@ -1057,7 +1075,27 @@ y_t = y; } - struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, t_emb_t, y_t); + // offload all control tensors to the gpu + if(control.size() > 0 && !ggml_backend_is_cpu(backend) && control[0]->backend != GGML_BACKEND_GPU) { + for(int i = 0; i < control.size(); i++) { + ggml_tensor* cntl_t = ggml_dup_tensor(ctx0, control[i]); + control_t.push_back(cntl_t); + ggml_allocr_alloc(compute_allocr, cntl_t); + if(!ggml_allocr_is_measure(compute_allocr)) { + ggml_backend_tensor_copy(control[i], control_t[i]); + ggml_backend_synchronize(backend); + } + } + } else { + control_t = control; + } + + ggml_allocr_alloc(compute_allocr, control_strength); + if(!ggml_allocr_is_measure(compute_allocr)) { + ggml_backend_tensor_set(control_strength, &control_net_strength, 0, sizeof(float)); + } + + struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, control_t, control_strength, t_emb_t, y_t); ggml_build_forward_expand(gf, out); ggml_free(ctx0); @@ -1067,10 +1105,12 @@ void alloc_compute_buffer(struct ggml_tensor* x, struct ggml_tensor* context, + std::vector<struct ggml_tensor*> control, struct ggml_tensor* t_emb = NULL, - struct ggml_tensor* y = NULL) { + struct ggml_tensor* y = NULL, + float control_net_strength = 1.0) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, NULL, context, t_emb, y); + return build_graph(x, NULL, context, control, t_emb, y, control_net_strength); }; GGMLModule::alloc_compute_buffer(get_graph); } @@ -1080,10 +1120,12 @@ struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + std::vector<struct ggml_tensor*> control, + float control_net_strength, struct ggml_tensor* t_emb = NULL, struct ggml_tensor* y = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, t_emb, y); + return build_graph(x, NULL, context, control, t_emb, y, control_net_strength); }; GGMLModule::compute(get_graph, n_threads, work_latent); From 1a2775230c72e3b134c4505a3faa5c7dd5652234 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Tue, 2 Jan 2024 14:24:44 -0500 Subject: [PATCH 15/29] add Textual Inversion --- clip.hpp | 574 ++++++++++++++++++++++++------------ examples/cli/main.cpp |
11 +- ggml_extend.hpp | 9 + stable-diffusion.cpp | 6 + stable-diffusion.h | 1 + unet.hpp | 2 +- util.cpp | 18 ++ util.h | 2 + 8 files changed, 377 insertions(+), 246 deletions(-) diff --git a/clip.hpp b/clip.hpp index f700b8086..48115cfde 100644 --- a/clip.hpp +++ b/clip.hpp @@ -2,6 +2,7 @@ #define __CLIP_HPP__ #include "ggml_extend.hpp" +#include "model.h" /*================================================== CLIPTokenizer ===================================================*/ @@ -66,242 +67,6 @@ std::vector> bytes_to_unicode() { return byte_unicode_pairs; } -// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py -class CLIPTokenizer { -private: - SDVersion version = VERSION_1_x; - std::map byte_encoder; - std::map encoder; - std::map, int> bpe_ranks; - std::regex pat; - - static std::string strip(const std::string& str) { - std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); - std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); - - if (start == std::string::npos) { - // String contains only whitespace characters - return ""; - } - - return str.substr(start, end - start + 1); - } - - static std::string whitespace_clean(std::string text) { - text = std::regex_replace(text, std::regex(R"(\s+)"), " "); - text = strip(text); - return text; - } - - static std::set> get_pairs(const std::vector& subwords) { - std::set> pairs; - if (subwords.size() == 0) { - return pairs; - } - std::u32string prev_subword = subwords[0]; - for (int i = 1; i < subwords.size(); i++) { - std::u32string subword = subwords[i]; - std::pair pair(prev_subword, subword); - pairs.insert(pair); - prev_subword = subword; - } - return pairs; - } - -public: - CLIPTokenizer(SDVersion version = VERSION_1_x) - : version(version) {} - - void load_from_merges(const std::string& merges_utf8_str) { - auto byte_unicode_pairs = bytes_to_unicode(); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); - // for (auto & pair: byte_unicode_pairs) { - // std::cout << pair.first << ": " << pair.second << std::endl; - // } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - // LOG_DEBUG("merges size %llu", merges.size()); - GGML_ASSERT(merges.size() == 48895); - merges = std::vector(merges.begin() + 1, merges.end()); - std::vector> merge_pairs; - for (const auto& merge : merges) { - size_t space_pos = merge.find(' '); - merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - } - std::vector vocab; - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second); - } - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second + utf8_to_utf32("")); - } - for (const auto& merge : merge_pairs) { - vocab.push_back(merge.first + merge.second); - } - vocab.push_back(utf8_to_utf32("<|startoftext|>")); - vocab.push_back(utf8_to_utf32("<|endoftext|>")); - LOG_DEBUG("vocab size: %llu", vocab.size()); - int i = 0; - for (const auto& token : vocab) { - encoder[token] = i++; - } - - int rank = 0; - for (const auto& merge : merge_pairs) { - bpe_ranks[merge] = rank++; - } - }; - - std::u32string bpe(const std::u32string& token) { - std::vector word; - - for (int i = 0; i < token.size() - 1; i++) { - 
word.emplace_back(1, token[i]); - } - word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("")); - - std::set> pairs = get_pairs(word); - - if (pairs.empty()) { - return token + utf8_to_utf32(""); - } - - while (true) { - auto min_pair_iter = std::min_element(pairs.begin(), - pairs.end(), - [&](const std::pair& a, - const std::pair& b) { - if (bpe_ranks.find(a) == bpe_ranks.end()) { - return false; - } else if (bpe_ranks.find(b) == bpe_ranks.end()) { - return true; - } - return bpe_ranks.at(a) < bpe_ranks.at(b); - }); - - const std::pair& bigram = *min_pair_iter; - - if (bpe_ranks.find(bigram) == bpe_ranks.end()) { - break; - } - - std::u32string first = bigram.first; - std::u32string second = bigram.second; - std::vector new_word; - int32_t i = 0; - - while (i < word.size()) { - auto it = std::find(word.begin() + i, word.end(), first); - if (it == word.end()) { - new_word.insert(new_word.end(), word.begin() + i, word.end()); - break; - } - new_word.insert(new_word.end(), word.begin() + i, it); - i = static_cast(std::distance(word.begin(), it)); - - if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { - new_word.push_back(first + second); - i += 2; - } else { - new_word.push_back(word[i]); - i += 1; - } - } - - word = new_word; - - if (word.size() == 1) { - break; - } - pairs = get_pairs(word); - } - - std::u32string result; - for (int i = 0; i < word.size(); i++) { - result += word[i]; - if (i != word.size() - 1) { - result += utf8_to_utf32(" "); - } - } - - return result; - } - - std::vector tokenize(std::string text, size_t max_length = 0, bool padding = false) { - std::vector tokens = encode(text); - tokens.insert(tokens.begin(), BOS_TOKEN_ID); - if (max_length > 0) { - if (tokens.size() > max_length - 1) { - tokens.resize(max_length - 1); - tokens.push_back(EOS_TOKEN_ID); - } else { - tokens.push_back(EOS_TOKEN_ID); - if (padding) { - int pad_token_id = PAD_TOKEN_ID; - if (version == VERSION_2_x) { - pad_token_id = 0; - } - tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id); - } - } - } - return tokens; - } - - std::vector encode(std::string text) { - std::string original_text = text; - std::vector bpe_tokens; - text = whitespace_clean(text); - std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); }); - - std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)", - std::regex::icase); - - std::smatch matches; - std::string str = text; - std::vector token_strs; - while (std::regex_search(str, matches, pat)) { - for (auto& token : matches) { - std::string token_str = token.str(); - std::u32string utf32_token; - for (int i = 0; i < token_str.length(); i++) { - char b = token_str[i]; - utf32_token += byte_encoder[b]; - } - auto bpe_strs = bpe(utf32_token); - size_t start = 0; - size_t pos; - while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { - auto bpe_str = bpe_strs.substr(start, pos - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - - start = pos + 1; - } - auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - } - str = matches.suffix(); - } - std::stringstream ss; - ss << "["; - for (auto token : token_strs) { - ss << "\"" << token << "\", "; - } - ss << "]"; - LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), 
ss.str().c_str()); - return bpe_tokens; - } -}; - // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345 // // Parses a string with attention tokens and returns a list of pairs: text and its associated weight. @@ -604,6 +369,7 @@ struct CLIPTextModel { struct ggml_tensor* position_ids; struct ggml_tensor* token_embed_weight; struct ggml_tensor* position_embed_weight; + struct ggml_tensor* token_embed_custom; // transformer std::vector<ResidualAttentionBlock> resblocks; struct ggml_tensor* final_ln_b; struct ggml_tensor* text_projection; + std::string embd_dir; + int32_t num_custom_embeddings = 0; + std::vector<std::string> readed_embeddings; CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, int clip_skip = -1, @@ -653,6 +422,9 @@ mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(GGML_TYPE_I32); // position_ids mem_size += hidden_size * vocab_size * ggml_type_sizef(wtype); // token_embed_weight mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(wtype); // position_embed_weight + if(version == OPENAI_CLIP_VIT_L_14) { + mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(wtype); // token_embed_custom + } for (int i = 0; i < num_hidden_layers; i++) { mem_size += resblocks[i].calculate_mem_size(wtype); }
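The flow these clip.hpp hunks set up: the tokenizer (moved below so it can see CLIPTextModel) spots an embedding name in the prompt, load_embedding (added in the next hunk) copies its vectors into token_embed_custom and hands back ids past the end of the regular vocabulary, and forward() later resolves those ids against a concatenated embedding table. A rough trace, assuming SD 1.5 numbers (vocab_size 49408, hidden_size 768) and a hypothetical my-style.safetensors holding 4 vectors:

    // illustrative only -- the file name and the 4-vector shape are assumptions
    std::vector<int32_t> ids;
    text_model.load_embedding("my-style", "embeddings/my-style.safetensors", ids);
    // ids == {49408, 49409, 49410, 49411}; num_custom_embeddings == 4, and
    // rows 0..3 of token_embed_custom now hold the loaded [768] vectors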
@@ -677,14 +449,48 @@ } } - struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, size_t max_token_idx = 0, bool return_pooled = false) { + bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int32_t>& bpe_tokens) { + // note: loading order matters, since each embedding is appended to token_embed_custom + ModelLoader model_loader; + if(!model_loader.init_from_file(embd_path)) { + LOG_ERROR("load embedding '%s' failed", embd_name.c_str()); + return false; + } + struct ggml_init_params params; + params.mem_size = 32 * 1024; // max for custom embeddings 32 KB + params.mem_buffer = NULL; + params.no_alloc = false; + struct ggml_context* embd_ctx = ggml_init(params); + struct ggml_tensor* embd = NULL; + auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) { + if(tensor_storage.ne[0] != hidden_size) { + LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size); + return false; + } + embd = ggml_new_tensor_2d(embd_ctx, token_embed_weight->type, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1); + *dst_tensor = embd; + return true; + }; + model_loader.load_tensors(on_load, NULL); + ggml_backend_tensor_set(token_embed_custom, embd->data, num_custom_embeddings * hidden_size * ggml_type_size(token_embed_custom->type), ggml_nbytes(embd)); + readed_embeddings.push_back(embd_name); + for(int i = 0; i < embd->ne[1]; i++) { + bpe_tokens.push_back(vocab_size + num_custom_embeddings); + // LOG_DEBUG("new custom token: %i", vocab_size + num_custom_embeddings); + num_custom_embeddings++; + } + LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings); + return true; + } + + struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* tkn_embeddings, uint32_t max_token_idx = 0, bool return_pooled = false) { // input_ids: [N, n_token] GGML_ASSERT(input_ids->ne[0] <= position_ids->ne[0]); // token_embedding + position_embedding struct ggml_tensor* x; x = ggml_add(ctx0, - ggml_get_rows(ctx0, token_embed_weight, input_ids), + ggml_get_rows(ctx0, tkn_embeddings == NULL ? token_embed_weight : tkn_embeddings, input_ids), ggml_get_rows(ctx0, position_embed_weight, ggml_view_1d(ctx0, position_ids, input_ids->ne[0], 0))); // [N, n_token, hidden_size] @@ -730,6 +536,10 @@ final_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + if(version == OPENAI_CLIP_VIT_L_14) { + token_embed_custom = ggml_new_tensor_2d(ctx, wtype, hidden_size, max_position_embeddings); + } + if (version == OPEN_CLIP_VIT_BIGG_14) { text_projection = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size); } @@ -755,6 +565,263 @@ } }; + +// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py +class CLIPTokenizer { +private: + SDVersion version = VERSION_1_x; + std::map<int, std::u32string> byte_encoder; + std::map<std::u32string, int> encoder; + std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks; + std::regex pat; + + static std::string strip(const std::string& str) { + std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); + std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); + + if (start == std::string::npos) { + // String contains only whitespace characters + return ""; + } + + return str.substr(start, end - start + 1); + } + + static std::string whitespace_clean(std::string text) { + text = std::regex_replace(text, std::regex(R"(\s+)"), " "); + text = strip(text); + return text; + } + + static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) { + std::set<std::pair<std::u32string, std::u32string>> pairs; + if (subwords.size() == 0) { + return pairs; + } + std::u32string prev_subword = subwords[0]; + for (int i = 1; i < subwords.size(); i++) { + std::u32string subword = subwords[i]; + std::pair<std::u32string, std::u32string> pair(prev_subword, subword); + pairs.insert(pair); + prev_subword = subword; + } + return pairs; + } + +public: + CLIPTokenizer(SDVersion version = VERSION_1_x) + : version(version) {} + + void load_from_merges(const std::string& merges_utf8_str) { + auto byte_unicode_pairs = bytes_to_unicode(); + byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + // for (auto & pair: byte_unicode_pairs) { + // std::cout << pair.first << ": " << pair.second << std::endl; + // } + std::vector<std::u32string> merges; + size_t start = 0; + size_t pos; + std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); + while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { + merges.push_back(merges_utf32_str.substr(start, pos - start)); + start = pos + 1; + } + // LOG_DEBUG("merges size %llu", merges.size()); + GGML_ASSERT(merges.size() == 48895); + merges = std::vector<std::u32string>(merges.begin() + 1, merges.end()); + std::vector<std::pair<std::u32string, std::u32string>> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); + } + std::vector<std::u32string> vocab; + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second); + } + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second + utf8_to_utf32("</w>")); + } + for (const auto& merge : merge_pairs) { + vocab.push_back(merge.first + merge.second); + } + vocab.push_back(utf8_to_utf32("<|startoftext|>")); + vocab.push_back(utf8_to_utf32("<|endoftext|>")); + LOG_DEBUG("vocab size: %llu", vocab.size()); + int i = 0; + for (const auto& token : vocab) { + encoder[token] = i++; + } + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + }; + + std::u32string bpe(const std::u32string& token) { + std::vector<std::u32string> word; + + for (int i = 0; i
< token.size() - 1; i++) { + word.emplace_back(1, token[i]); + } + word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>")); + + std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word); + + if (pairs.empty()) { + return token + utf8_to_utf32("</w>"); + } + + while (true) { + auto min_pair_iter = std::min_element(pairs.begin(), + pairs.end(), + [&](const std::pair<std::u32string, std::u32string>& a, + const std::pair<std::u32string, std::u32string>& b) { + if (bpe_ranks.find(a) == bpe_ranks.end()) { + return false; + } else if (bpe_ranks.find(b) == bpe_ranks.end()) { + return true; + } + return bpe_ranks.at(a) < bpe_ranks.at(b); + }); + + const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter; + + if (bpe_ranks.find(bigram) == bpe_ranks.end()) { + break; + } + + std::u32string first = bigram.first; + std::u32string second = bigram.second; + std::vector<std::u32string> new_word; + int32_t i = 0; + + while (i < word.size()) { + auto it = std::find(word.begin() + i, word.end(), first); + if (it == word.end()) { + new_word.insert(new_word.end(), word.begin() + i, word.end()); + break; + } + new_word.insert(new_word.end(), word.begin() + i, it); + i = static_cast<int32_t>(std::distance(word.begin(), it)); + + if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) { + new_word.push_back(first + second); + i += 2; + } else { + new_word.push_back(word[i]); + i += 1; + } + } + + word = new_word; + + if (word.size() == 1) { + break; + } + pairs = get_pairs(word); + } + + std::u32string result; + for (int i = 0; i < word.size(); i++) { + result += word[i]; + if (i != word.size() - 1) { + result += utf8_to_utf32(" "); + } + } + + return result; + } + + std::vector<int> tokenize(std::string text, CLIPTextModel& text_model, size_t max_length = 0, bool padding = false) { + std::vector<int> tokens = encode(text, text_model); + tokens.insert(tokens.begin(), BOS_TOKEN_ID); + if (max_length > 0) { + if (tokens.size() > max_length - 1) { + tokens.resize(max_length - 1); + tokens.push_back(EOS_TOKEN_ID); + } else { + tokens.push_back(EOS_TOKEN_ID); + if (padding) { + int pad_token_id = PAD_TOKEN_ID; + if (version == VERSION_2_x) { + pad_token_id = 0; + } + tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id); + } + } + } + return tokens; + } + + std::vector<int> encode(std::string text, CLIPTextModel& text_model) { + std::string original_text = text; + std::vector<int> bpe_tokens; + text = whitespace_clean(text); + std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); }); + + std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)", + std::regex::icase); + + std::smatch matches; + std::string str = text; + std::vector<std::string> token_strs; + while (std::regex_search(str, matches, pat)) { + size_t word_end = str.find(","); + std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end); + embd_name = trim(embd_name); + std::string embd_path = path_join(text_model.embd_dir, embd_name + ".pt"); + if(!file_exists(embd_path)) { + embd_path = path_join(text_model.embd_dir, embd_name + ".ckpt"); + } + if(!file_exists(embd_path)) { + embd_path = path_join(text_model.embd_dir, embd_name + ".safetensors"); + } + if(file_exists(embd_path)) { + if(text_model.load_embedding(embd_name, embd_path, bpe_tokens)) { + if(word_end != std::string::npos) { + str = str.substr(word_end); + } else { + str = ""; + } + continue; + } + } + for (auto& token : matches) { + std::string token_str = token.str(); + std::u32string utf32_token; + for (int i = 0; i < token_str.length(); i++) { + char b = token_str[i]; + utf32_token += byte_encoder[b]; + } + auto bpe_strs = bpe(utf32_token); + size_t start = 0; + size_t pos; + while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { + auto bpe_str = bpe_strs.substr(start, pos - start); + bpe_tokens.push_back(encoder[bpe_str]); + token_strs.push_back(utf32_to_utf8(bpe_str)); + + start = pos + 1; + } + auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); + bpe_tokens.push_back(encoder[bpe_str]); + token_strs.push_back(utf32_to_utf8(bpe_str)); + } + str = matches.suffix(); + } + std::stringstream ss; + ss << "["; + for (auto token : token_strs) { + ss << "\"" << token << "\", "; + } + ss << "]"; + LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); + return bpe_tokens; + } +};
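Driving the reworked tokenizer end to end, a sketch (the merges text, the embd_dir contents, and the prompt are assumptions):

    CLIPTokenizer tokenizer(VERSION_1_x);
    tokenizer.load_from_merges(merges_utf8);  // BPE merges text, assumed already read from disk
    text_model.embd_dir = "embeddings";       // folder containing my-style.safetensors, assumed
    std::vector<int> ids = tokenizer.tokenize("a photo of my-style", text_model, 77, true);
    // "my-style" matches a file under embd_dir, so encode() routes it through
    // text_model.load_embedding and emits custom ids >= vocab_size instead of BPE tokens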
// ldm.modules.encoders.modules.FrozenCLIPEmbedder // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { @@ -812,11 +879,11 @@ } } - struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* input_ids2, size_t max_token_idx = 0, bool return_pooled = false) { + struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* input_ids2, struct ggml_tensor* embeddings, size_t max_token_idx = 0, bool return_pooled = false) { if (return_pooled) { - return text_model2.forward(ctx0, input_ids2, max_token_idx, return_pooled); + return text_model2.forward(ctx0, input_ids2, NULL, max_token_idx, return_pooled); } - auto hidden_states = text_model.forward(ctx0, input_ids); // [N, n_token, hidden_size] + auto hidden_states = text_model.forward(ctx0, input_ids, embeddings); // [N, n_token, hidden_size] // LOG_DEBUG("hidden_states: %d %d %d %d %d", hidden_states->n_dims, hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]); if (version == VERSION_XL) { hidden_states = ggml_reshape_4d(ctx0, hidden_states, hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]); hidden_states = ggml_cont(ctx0, ggml_permute(ctx0, hidden_states, 2, 0, 1, 3)); - auto hidden_states2 = text_model2.forward(ctx0, input_ids2); // [N, n_token, hidden_size2] + auto hidden_states2 = text_model2.forward(ctx0, input_ids2, NULL); // [N, n_token, hidden_size2] hidden_states2 = ggml_reshape_4d(ctx0, hidden_states2, hidden_states2->ne[0], @@ -869,7 +936,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector<int> curr_tokens =
tokenizer.encode(curr_text); + std::vector curr_tokens = tokenizer.encode(curr_text, text_model); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); weights.insert(weights.end(), curr_tokens.size(), curr_weight); } @@ -958,7 +1025,26 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { } } - struct ggml_tensor* hidden_states = forward(ctx0, input_ids, input_ids2, max_token_idx, return_pooled); + struct ggml_tensor* embeddings = NULL; + + if(text_model.num_custom_embeddings > 0 && version != VERSION_XL) { + embeddings = ggml_new_tensor_2d(ctx0, wtype, text_model.hidden_size, text_model.vocab_size + text_model.num_custom_embeddings /* custom placeholder */); + ggml_allocr_alloc(allocr, embeddings); + if (!ggml_allocr_is_measure(allocr)) { + // really bad, there is memory inflexibility (this is for host<->device memory conflicts) + void* freeze_data = malloc(ggml_nbytes(text_model.token_embed_weight)); + ggml_backend_tensor_get_and_sync(backend, text_model.token_embed_weight, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight)); + ggml_backend_tensor_set_and_sync(backend, embeddings, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight)); + free(freeze_data); + // concatenate custom embeddings + void* custom_data = malloc(ggml_nbytes(text_model.token_embed_custom)); + ggml_backend_tensor_get_and_sync(backend, text_model.token_embed_custom, custom_data, 0, ggml_nbytes(text_model.token_embed_custom)); + ggml_backend_tensor_set_and_sync(backend, embeddings, custom_data, ggml_nbytes(text_model.token_embed_weight), text_model.num_custom_embeddings * text_model.hidden_size * ggml_type_size(wtype)); + free(custom_data); + } + } + + struct ggml_tensor* hidden_states = forward(ctx0, input_ids, input_ids2, embeddings, max_token_idx, return_pooled); ggml_build_forward_expand(gf, hidden_states); ggml_free(ctx0); diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index e7311fc5f..78f9cb454 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -59,6 +59,7 @@ struct SDParams { std::string taesd_path; std::string esrgan_path; std::string controlnet_path; + std::string embeddings_path; sd_type_t wtype = SD_TYPE_COUNT; std::string lora_model_dir; std::string output_path = "output.png"; @@ -136,6 +137,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); + printf(" --embd-dir [EMBEDDING_PATH] path to embeddings.\n"); printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. 
Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n"); printf(" If not specified, the default is the type of the weight file.\n"); @@ -225,7 +227,13 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.esrgan_path = argv[i]; - } else if (arg == "--type") { + } else if (arg == "--embd-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.embeddings_path = argv[i]; + } else if (arg == "--type") { if (++i >= argc) { invalid_arg = true; break; @@ -540,6 +548,7 @@ int main(int argc, const char* argv[]) { params.taesd_path.c_str(), params.controlnet_path.c_str(), params.lora_model_dir.c_str(), + params.embeddings_path.c_str(), vae_decode_only, params.vae_tiling, true, diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 480ea77ca..545e22645 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -460,6 +460,15 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct return x; } +__STATIC_INLINE__ void ggml_backend_tensor_set_and_sync(ggml_backend_t backend, struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { +#ifdef SD_USE_CUBLAS + ggml_backend_tensor_set_async(backend, tensor, data, offset, size); + ggml_backend_synchronize(backend); +#else + ggml_backend_tensor_set(tensor, data, offset, size); +#endif +} + __STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, size_t size) { #ifdef SD_USE_CUBLAS ggml_backend_tensor_get_async(backend, tensor, data, offset, size); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 24fa1eae9..5263d7203 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -110,6 +110,7 @@ class StableDiffusionGGML { bool load_from_file(const std::string& model_path, const std::string& vae_path, const std::string control_net_path, + const std::string embeddings_path, const std::string& taesd_path, bool vae_tiling, ggml_type wtype, @@ -188,6 +189,8 @@ class StableDiffusionGGML { return false; } + cond_stage_model.text_model.embd_dir = embeddings_path; + ggml_type vae_type = model_data_type; if (version == VERSION_XL) { vae_type = GGML_TYPE_F32; // avoid nan, not work... 
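A note on the ggml_backend_tensor_set_and_sync helper added in the ggml_extend.hpp hunk above: it mirrors the existing get_and_sync, pairing the asynchronous upload with an explicit ggml_backend_synchronize under SD_USE_CUBLAS so the copy cannot overlap with queued graph work (patch 16 below sweeps the existing ggml_backend_tensor_set call sites over to it). A minimal sketch of the call pattern (the tensor and value here are illustrative):

    float strength = 0.9f;  // host-side value to stage, illustrative
    ggml_backend_tensor_set_and_sync(backend, control_strength, &strength, 0, sizeof(float));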
@@ -1130,6 +1133,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, const char* taesd_path_c_str, const char* control_net_path_c_str, const char* lora_model_dir_c_str, + const char* embed_dir_c_str, bool vae_decode_only, bool vae_tiling, bool free_params_immediately, @@ -1146,6 +1150,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, std::string vae_path(vae_path_c_str); std::string taesd_path(taesd_path_c_str); std::string control_net_path(control_net_path_c_str); + std::string embd_path(embed_dir_c_str); std::string lora_model_dir(lora_model_dir_c_str); sd_ctx->sd = new StableDiffusionGGML(n_threads, @@ -1160,6 +1165,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, if (!sd_ctx->sd->load_from_file(model_path, vae_path, control_net_path, + embd_path, taesd_path, vae_tiling, (ggml_type)wtype, diff --git a/stable-diffusion.h b/stable-diffusion.h index e7f91e973..f4699b788 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -106,6 +106,7 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path, const char* taesd_path, const char* control_net_path_c_str, const char* lora_model_dir, + const char* embed_dir_c_str, bool vae_decode_only, bool vae_tiling, bool free_params_immediately, diff --git a/unet.hpp b/unet.hpp index 9d9c20395..54a011571 100644 --- a/unet.hpp +++ b/unet.hpp @@ -1125,7 +1125,7 @@ struct UNetModel : public GGMLModule { struct ggml_tensor* t_emb = NULL, struct ggml_tensor* y = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, NULL, context, control, t_emb, y, control_net_strength); + return build_graph(x, timesteps, context, control, t_emb, y, control_net_strength); }; GGMLModule::compute(get_graph, n_threads, work_latent); diff --git a/util.cpp b/util.cpp index 4057d13d0..85e81d6f2 100644 --- a/util.cpp +++ b/util.cpp @@ -192,6 +192,24 @@ void pretty_progress(int step, int steps, float time) { } } +std::string ltrim(const std::string& s) { + auto it = std::find_if(s.begin(), s.end(), [](int ch) { + return !std::isspace(ch); + }); + return std::string(it, s.end()); +} + +std::string rtrim(const std::string& s) { + auto it = std::find_if(s.rbegin(), s.rend(), [](int ch) { + return !std::isspace(ch); + }); + return std::string(s.begin(), it.base()); +} + +std::string trim(const std::string& s) { + return rtrim(ltrim(s)); +} + static sd_log_cb_t sd_log_cb = NULL; void* sd_log_cb_data = NULL; diff --git a/util.h b/util.h index 3a611655d..ca0830e12 100644 --- a/util.h +++ b/util.h @@ -28,6 +28,8 @@ void pretty_progress(int step, int steps, float time); void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...); +std::string trim(const std::string& s); + #define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__) #define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__) #define LOG_WARN(format, ...) 
log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__) From af8cf682430245042d12f15ed326fe14a8976c06 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Tue, 2 Jan 2024 14:32:35 -0500 Subject: [PATCH 16/29] fix cuda sync issues + fix ci errors --- README.md | 9 ++++++--- clip.hpp | 4 ++-- control.hpp | 10 +++++----- esrgan.hpp | 2 +- lora.hpp | 2 +- tae.hpp | 2 +- unet.hpp | 12 ++++++------ util.cpp | 2 +- vae.hpp | 2 +- 9 files changed, 24 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 159f4277d..5e899e628 100644 --- a/README.md +++ b/README.md @@ -54,9 +54,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - [ ] More sampling methods - [ ] Make inference faster - The current implementation of ggml_conv_2d is slow and has high memory usage - - Implement Winograd Convolution 2D for 3x3 kernel filtering - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d) -- [ ] Implement Textual Inversion (embeddings) - [ ] Implement Inpainting support - [ ] k-quants support @@ -149,16 +147,20 @@ arguments: -m, --model [MODEL] path to model --vae [VAE] path to vae --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) + --control-net [CONTROL_PATH] path to control net model + --embd-dir [EMBEDDING_PATH] path to embeddings. --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now. --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) If not specified, the default is the type of the weight file. --lora-model-dir [DIR] lora model directory -i, --init-img [IMAGE] path to the input image, required by img2img + --control-image [IMAGE] path to image condition, control net -o, --output OUTPUT path to write result image to (default: ./output.png) -p, --prompt [PROMPT] the prompt to render -n, --negative-prompt PROMPT the negative prompt (default: "") --cfg-scale SCALE unconditional guidance scale: (default: 7.0) --strength STRENGTH strength for noising/unnoising (default: 0.75) + --control-strength STRENGTH strength to apply Control Net (default: 0.9) 1.0 corresponds to full destruction of information in init image -H, --height H image height, in pixel space (default: 512) -W, --width W image width, in pixel space (default: 512) @@ -169,8 +171,9 @@ arguments: -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0) -b, --batch-count COUNT number of images to generate. 
--schedule {discrete, karras} Denoiser sigma schedule (default: discrete) - --clip-skip N number of layers to skip of clip model (default: 0) + --clip-skip N number of layers to skip of clip model (default: -1) --vae-tiling process vae in tiles to reduce memory usage + --control-net-cpu keep controlnet in cpu (for low vram) -v, --verbose print extra info ``` diff --git a/clip.hpp b/clip.hpp index 48115cfde..64bb7c670 100644 --- a/clip.hpp +++ b/clip.hpp @@ -999,7 +999,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { ggml_allocr_alloc(allocr, input_ids); if (!ggml_allocr_is_measure(allocr)) { - ggml_backend_tensor_set(input_ids, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids)); + ggml_backend_tensor_set_and_sync(backend, input_ids, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids)); } struct ggml_tensor* input_ids2 = NULL; @@ -1021,7 +1021,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { // printf("\n"); if (!ggml_allocr_is_measure(allocr)) { - ggml_backend_tensor_set(input_ids2, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids2)); + ggml_backend_tensor_set_and_sync(backend, input_ids2, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids2)); } } diff --git a/control.hpp b/control.hpp index c15de901d..cf8d88be5 100644 --- a/control.hpp +++ b/control.hpp @@ -571,14 +571,14 @@ struct ControlNet : public GGMLModule { } // pass data to device backend if (!ggml_allocr_is_measure(compute_allocr)) { - ggml_backend_tensor_set(x_t, x->data, 0, ggml_nbytes(x)); - ggml_backend_tensor_set(context_t, context->data, 0, ggml_nbytes(context)); - ggml_backend_tensor_set(hint_t, hint->data, 0, ggml_nbytes(hint)); + ggml_backend_tensor_set_and_sync(backend, x_t, x->data, 0, ggml_nbytes(x)); + ggml_backend_tensor_set_and_sync(backend, context_t, context->data, 0, ggml_nbytes(context)); + ggml_backend_tensor_set_and_sync(backend, hint_t, hint->data, 0, ggml_nbytes(hint)); if (timesteps_t != NULL) { - ggml_backend_tensor_set(timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps)); + ggml_backend_tensor_set_and_sync(backend, timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps)); } if (t_emb_t != NULL) { - ggml_backend_tensor_set(t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); + ggml_backend_tensor_set_and_sync(backend, t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); } } } else { diff --git a/esrgan.hpp b/esrgan.hpp index a7c8ac79d..9b9b04abc 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -391,7 +391,7 @@ struct ESRGAN : public GGMLModule { // pass data to device backend if (!ggml_allocr_is_measure(compute_allocr)) { - ggml_backend_tensor_set(x_, x->data, 0, ggml_nbytes(x)); + ggml_backend_tensor_set_and_sync(backend, x_, x->data, 0, ggml_nbytes(x)); } } else { x_ = x; diff --git a/lora.hpp b/lora.hpp index 7f22136ab..206a868c9 100644 --- a/lora.hpp +++ b/lora.hpp @@ -127,7 +127,7 @@ struct LoraModel : public GGMLModule { ggml_allocr_alloc(compute_allocr, lora_scale); if (!ggml_allocr_is_measure(compute_allocr)) { - ggml_backend_tensor_set(lora_scale, &scale_value, 0, ggml_nbytes(lora_scale)); + ggml_backend_tensor_set_and_sync(backend, lora_scale, &scale_value, 0, ggml_nbytes(lora_scale)); } // flat lora tensors to multiply it diff --git a/tae.hpp b/tae.hpp index 422fd78f8..868b7736c 100644 --- a/tae.hpp +++ b/tae.hpp @@ -562,7 +562,7 @@ struct TinyAutoEncoder : public GGMLModule { // pass data to device backend if (!ggml_allocr_is_measure(compute_allocr)) { - ggml_backend_tensor_set(z_, z->data, 0, ggml_nbytes(z)); + 
ggml_backend_tensor_set_and_sync(backend, z_, z->data, 0, ggml_nbytes(z));
             }
         } else {
             z_ = z;
diff --git a/unet.hpp b/unet.hpp
index 54a011571..1e5edee11 100644
--- a/unet.hpp
+++ b/unet.hpp
@@ -1054,16 +1054,16 @@ struct UNetModel : public GGMLModule {
             }
             // pass data to device backend
             if (!ggml_allocr_is_measure(compute_allocr)) {
-                ggml_backend_tensor_set(x_t, x->data, 0, ggml_nbytes(x));
-                ggml_backend_tensor_set(context_t, context->data, 0, ggml_nbytes(context));
+                ggml_backend_tensor_set_and_sync(backend, x_t, x->data, 0, ggml_nbytes(x));
+                ggml_backend_tensor_set_and_sync(backend, context_t, context->data, 0, ggml_nbytes(context));
                 if (timesteps_t != NULL) {
-                    ggml_backend_tensor_set(timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps));
+                    ggml_backend_tensor_set_and_sync(backend, timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps));
                 }
                 if (t_emb_t != NULL) {
-                    ggml_backend_tensor_set(t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb));
+                    ggml_backend_tensor_set_and_sync(backend, t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb));
                 }
                 if (y != NULL) {
-                    ggml_backend_tensor_set(y_t, y->data, 0, ggml_nbytes(y));
+                    ggml_backend_tensor_set_and_sync(backend, y_t, y->data, 0, ggml_nbytes(y));
                 }
             }
         } else {
@@ -1092,7 +1092,7 @@ struct UNetModel : public GGMLModule {
         ggml_allocr_alloc(compute_allocr, control_strength);
         if(!ggml_allocr_is_measure(compute_allocr)) {
-            ggml_backend_tensor_set(control_strength, &control_net_strength, 0, sizeof(float));
+            ggml_backend_tensor_set_and_sync(backend, control_strength, &control_net_strength, 0, sizeof(float));
         }

         struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, control_t, control_strength, t_emb_t, y_t);
diff --git a/util.cpp b/util.cpp
index 85e81d6f2..e8b9e4250 100644
--- a/util.cpp
+++ b/util.cpp
@@ -1,5 +1,5 @@
 #include "util.h"
-
+#include <algorithm>
 #include 
 #include 
 #include 
diff --git a/vae.hpp b/vae.hpp
index 478d9efe4..1eccf78af 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -714,7 +714,7 @@ struct AutoEncoderKL : public GGMLModule {

             // pass data to device backend
             if (!ggml_allocr_is_measure(compute_allocr)) {
-                ggml_backend_tensor_set(z_, z->data, 0, ggml_nbytes(z));
+                ggml_backend_tensor_set_and_sync(backend, z_, z->data, 0, ggml_nbytes(z));
             }
         } else {
             z_ = z;
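Patch 16 routes every host-to-device upload through ggml_backend_tensor_set_and_sync, whose definition lives in ggml_extend.hpp and is not shown in these hunks. A minimal sketch of its likely shape, assuming the ggml backend API of this period (a plain ggml_backend_tensor_set followed by an explicit synchronize so CUDA's asynchronous copy lands before the graph reads the data); this is illustrative, not the verbatim definition:

    // Sketch only: upload host data into a backend tensor, then block until
    // the transfer completes (synchronization is skipped for the CPU backend,
    // where the copy is already synchronous).
    __STATIC_INLINE__ void ggml_backend_tensor_set_and_sync(ggml_backend_t backend,
                                                            struct ggml_tensor* tensor,
                                                            const void* data,
                                                            size_t offset,
                                                            size_t size) {
        ggml_backend_tensor_set(tensor, data, offset, size);
        if (!ggml_backend_is_cpu(backend)) {
            ggml_backend_synchronize(backend);
        }
    }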
"true" : "false"); + printf(" strength(control): %.2f\n", params.control_strength); printf(" prompt: %s\n", params.prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" cfg_scale: %.2f\n", params.cfg_scale); From 92d9c7ca3fccd2d3987aa2a875c5d945883b90be Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 3 Jan 2024 22:44:49 -0500 Subject: [PATCH 18/29] add canny preprocessor --- examples/cli/main.cpp | 13 ++- ggml_extend.hpp | 2 +- preprocessing.hpp | 229 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 241 insertions(+), 3 deletions(-) create mode 100644 preprocessing.hpp diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 4b1db8be6..d9e74ed02 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -7,6 +7,7 @@ #include #include "stable-diffusion.h" +#include "preprocessing.hpp" #define STB_IMAGE_IMPLEMENTATION #include "stb_image.h" @@ -83,7 +84,8 @@ struct SDParams { int64_t seed = 42; bool verbose = false; bool vae_tiling = false; - bool control_net_cpu = false; + bool control_net_cpu = false; + bool canny_preprocess = false; }; static std::string sd_basename(const std::string& path) { @@ -169,6 +171,7 @@ void print_usage(int argc, const char* argv[]) { printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); + printf(" --canny apply canny preprocessor (edge detection)\n"); printf(" -v, --verbose print extra info\n"); } @@ -345,6 +348,8 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.vae_tiling = true; } else if (arg == "--control-net-cpu") { params.control_net_cpu = true; + } else if (arg == "--canny") { + params.canny_preprocess = true; } else if (arg == "-b" || arg == "--batch-count") { if (++i >= argc) { invalid_arg = true; @@ -578,10 +583,14 @@ int main(int argc, const char* argv[]) { fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str()); return 1; } - control_image = new sd_image_t{(uint32_t)params.width, + control_image = new sd_image_t{(uint32_t)params.width, (uint32_t)params.height, 3, input_image_buffer}; + if(params.canny_preprocess) { // apply preprocessor + LOG_INFO("Applying canny preprocessor"); + control_image->data = preprocess_canny(control_image->data, control_image->width, control_image->height); + } } results = txt2img(sd_ctx, params.prompt.c_str(), diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 545e22645..d21353545 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -219,7 +219,7 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data, for (int iy = 0; iy < height; iy++) { for (int ix = 0; ix < width; ix++) { for (int k = 0; k < channels; k++) { - float value = *(image_data + iy * width * channels + ix * channels + k); + int value = *(image_data + iy * width * channels + ix * channels + k); ggml_tensor_set_f32(output, value / 255.0f, ix, iy, k); } } diff --git a/preprocessing.hpp b/preprocessing.hpp new file mode 100644 index 000000000..d5bbd5647 --- /dev/null +++ b/preprocessing.hpp @@ -0,0 +1,229 @@ +#ifndef __PREPROCESSING_HPP__ +#define __PREPROCESSING_HPP__ + +#include "ggml_extend.hpp" +#define M_PI_ 3.14159265358979323846 + +void convolve(struct ggml_tensor* input, struct ggml_tensor* output, struct ggml_tensor* kernel, int padding) { + struct ggml_init_params params; + params.mem_size = 20 * 1024 * 1024; // 10 + params.mem_buffer = NULL; + 
+    params.no_alloc   = false;
+    struct ggml_context* ctx0        = ggml_init(params);
+    struct ggml_tensor* kernel_fp16  = ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, kernel->ne[0], kernel->ne[1], 1, 1);
+    ggml_fp32_to_fp16_row((float*)kernel->data, (ggml_fp16_t*)kernel_fp16->data, ggml_nelements(kernel));
+    ggml_tensor* h  = ggml_conv_2d(ctx0, kernel_fp16, input, 1, 1, padding, padding, 1, 1);
+    ggml_cgraph* gf = ggml_new_graph(ctx0);
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, h, output));
+    ggml_graph_compute_with_ctx(ctx0, gf, 1);
+    ggml_free(ctx0);
+}
+
+void gaussian_kernel(struct ggml_tensor* kernel) {
+    int ks_mid   = kernel->ne[0] / 2;
+    float sigma  = 1.4f;
+    float normal = 1.f / (2.0f * M_PI_ * powf(sigma, 2.0f));
+    for (int y = 0; y < kernel->ne[0]; y++) {
+        float gx = -ks_mid + y;
+        for (int x = 0; x < kernel->ne[1]; x++) {
+            float gy = -ks_mid + x;
+            float k_ = expf(-((gx * gx + gy * gy) / (2.0f * powf(sigma, 2.0f)))) * normal;
+            ggml_tensor_set_f32(kernel, k_, x, y);
+        }
+    }
+}
+
+void grayscale(struct ggml_tensor* rgb_img, struct ggml_tensor* grayscale) {
+    for (int iy = 0; iy < rgb_img->ne[1]; iy++) {
+        for (int ix = 0; ix < rgb_img->ne[0]; ix++) {
+            float r    = ggml_tensor_get_f32(rgb_img, ix, iy);
+            float g    = ggml_tensor_get_f32(rgb_img, ix, iy, 1);
+            float b    = ggml_tensor_get_f32(rgb_img, ix, iy, 2);
+            float gray = 0.2989f * r + 0.5870f * g + 0.1140f * b;
+            ggml_tensor_set_f32(grayscale, gray, ix, iy);
+        }
+    }
+}
+
+void prop_hypot(struct ggml_tensor* x, struct ggml_tensor* y, struct ggml_tensor* h) {
+    int n_elements = ggml_nelements(h);
+    float* dx      = (float*)x->data;
+    float* dy      = (float*)y->data;
+    float* dh      = (float*)h->data;
+    for (int i = 0; i < n_elements; i++) {
+        dh[i] = sqrtf(dx[i] * dx[i] + dy[i] * dy[i]);
+    }
+}
+
+void prop_arctan2(struct ggml_tensor* x, struct ggml_tensor* y, struct ggml_tensor* h) {
+    int n_elements = ggml_nelements(h);
+    float* dx      = (float*)x->data;
+    float* dy      = (float*)y->data;
+    float* dh      = (float*)h->data;
+    for (int i = 0; i < n_elements; i++) {
+        dh[i] = atan2f(dy[i], dx[i]);
+    }
+}
+
+void normalize_tensor(struct ggml_tensor* g) {
+    int n_elements = ggml_nelements(g);
+    float* dg      = (float*)g->data;
+    float max      = -INFINITY;
+    for (int i = 0; i < n_elements; i++) {
+        max = dg[i] > max ? dg[i] : max;
+    }
+    max = 1.0f / max;
+    for (int i = 0; i < n_elements; i++) {
+        dg[i] *= max;
+    }
+}
+
+void non_max_suppression(struct ggml_tensor* result, struct ggml_tensor* G, struct ggml_tensor* D) {
+    for (int iy = 1; iy < result->ne[1] - 1; iy++) {
+        for (int ix = 1; ix < result->ne[0] - 1; ix++) {
+            float angle = ggml_tensor_get_f32(D, ix, iy) * 180.0f / M_PI_;
+            angle       = angle < 0.0f ? angle += 180.0f : angle;
+            float q     = 1.0f;
+            float r     = 1.0f;
+
+            // angle 0
+            if ((0 <= angle && angle < 22.5f) || (157.5f <= angle && angle <= 180)) {
+                q = ggml_tensor_get_f32(G, ix, iy + 1);
+                r = ggml_tensor_get_f32(G, ix, iy - 1);
+            }
+            // angle 45
+            else if (22.5f <= angle && angle < 67.5f) {
+                q = ggml_tensor_get_f32(G, ix + 1, iy - 1);
+                r = ggml_tensor_get_f32(G, ix - 1, iy + 1);
+            }
+            // angle 90
+            else if (67.5f <= angle && angle < 112.5f) {
+                q = ggml_tensor_get_f32(G, ix + 1, iy);
+                r = ggml_tensor_get_f32(G, ix - 1, iy);
+            }
+            // angle 135
+            else if (112.5f <= angle && angle < 157.5f) {
+                q = ggml_tensor_get_f32(G, ix - 1, iy - 1);
+                r = ggml_tensor_get_f32(G, ix + 1, iy + 1);
+            }
+
+            float cur = ggml_tensor_get_f32(G, ix, iy);
+            if ((cur >= q) && (cur >= r)) {
+                ggml_tensor_set_f32(result, cur, ix, iy);
+            } else {
+                ggml_tensor_set_f32(result, 0.0f, ix, iy);
+            }
+        }
+    }
+}
+
+void threshold_hysteresis(struct ggml_tensor* img, float highThreshold, float lowThreshold, float weak, float strong) {
+    int n_elements = ggml_nelements(img);
+    float* imd     = (float*)img->data;
+    float max      = -INFINITY;
+    for (int i = 0; i < n_elements; i++) {
+        max = imd[i] > max ? imd[i] : max;
+    }
+    float ht = max * highThreshold;
+    float lt = ht * lowThreshold;
+    for (int i = 0; i < n_elements; i++) {
+        float img_v = imd[i];
+        if (img_v >= ht) {  // strong pixel
+            imd[i] = strong;
+        } else if (img_v >= lt) {  // weak pixel
+            imd[i] = weak;
+        } else {  // suppress everything below the low threshold
+            imd[i] = 0.0f;
+        }
+    }
+
+    // zero out a 3-pixel border, where the convolutions are unreliable
+    for (int iy = 0; iy < img->ne[1]; iy++) {
+        for (int ix = 0; ix < img->ne[0]; ix++) {
+            if (ix < 3 || ix > img->ne[0] - 3 || iy < 3 || iy > img->ne[1] - 3) {
+                ggml_tensor_set_f32(img, 0.0f, ix, iy);
+            }
+        }
+    }
+
+    // hysteresis: keep a weak pixel only if one of its 8 neighbors is strong
+    for (int iy = 1; iy < img->ne[1] - 1; iy++) {
+        for (int ix = 1; ix < img->ne[0] - 1; ix++) {
+            float imd_v = ggml_tensor_get_f32(img, ix, iy);
+            if (imd_v == weak) {
+                if (ggml_tensor_get_f32(img, ix + 1, iy - 1) == strong || ggml_tensor_get_f32(img, ix + 1, iy) == strong ||
+                    ggml_tensor_get_f32(img, ix + 1, iy + 1) == strong || ggml_tensor_get_f32(img, ix, iy - 1) == strong ||
+                    ggml_tensor_get_f32(img, ix, iy + 1) == strong || ggml_tensor_get_f32(img, ix - 1, iy - 1) == strong ||
+                    ggml_tensor_get_f32(img, ix - 1, iy) == strong || ggml_tensor_get_f32(img, ix - 1, iy + 1) == strong) {
+                    ggml_tensor_set_f32(img, strong, ix, iy);
+                } else {
+                    ggml_tensor_set_f32(img, 0.0f, ix, iy);
+                }
+            }
+        }
+    }
+}
+
+uint8_t* preprocess_canny(uint8_t* img, int width, int height, float highThreshold = 0.08f, float lowThreshold = 0.08f, float weak = 0.8f, float strong = 1.0f, bool inverse = false) {
+    struct ggml_init_params params;
+    params.mem_size   = static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
+    params.mem_buffer = NULL;
+    params.no_alloc   = false;
+    struct ggml_context* work_ctx = ggml_init(params);
+
+    if (!work_ctx) {
+        LOG_ERROR("ggml_init() failed");
+        return NULL;
+    }
+
+    // Sobel kernels
+    float kX[9] = {
+        -1, 0, 1,
+        -2, 0, 2,
+        -1, 0, 1};
+
+    float kY[9] = {
+        1, 2, 1,
+        0, 0, 0,
+        -1, -2, -1};
+
+    // generate gaussian kernel
+    int kernel_size             = 5;
+    struct ggml_tensor* gkernel = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, kernel_size, kernel_size, 1, 1);
+    struct ggml_tensor* sf_kx   = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 3, 3, 1, 1);
+    memcpy(sf_kx->data, kX, ggml_nbytes(sf_kx));
+    struct ggml_tensor* sf_ky = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 3, 3, 1, 1);
+    memcpy(sf_ky->data, kY, ggml_nbytes(sf_ky));
+    gaussian_kernel(gkernel);
+    struct ggml_tensor* image      = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+    struct ggml_tensor* image_gray = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1);
+    struct ggml_tensor* iX         = ggml_dup_tensor(work_ctx, image_gray);
+    struct ggml_tensor* iY         = ggml_dup_tensor(work_ctx, image_gray);
+    struct ggml_tensor* G          = ggml_dup_tensor(work_ctx, image_gray);
+    struct ggml_tensor* theta      = ggml_dup_tensor(work_ctx, image_gray);
+    sd_image_to_tensor(img, image);
+    grayscale(image, image_gray);
+    convolve(image_gray, image_gray, gkernel, 2);
+    convolve(image_gray, iX, sf_kx, 1);
+    convolve(image_gray, iY, sf_ky, 1);
+    prop_hypot(iX, iY, G);
+    normalize_tensor(G);
+    prop_arctan2(iX, iY, theta);
+    non_max_suppression(image_gray, G, theta);
+    threshold_hysteresis(image_gray, highThreshold, lowThreshold, weak, strong);
+    // to RGB channels
+    for (int iy = 0; iy < height; iy++) {
+        for (int ix = 0; ix < width; ix++) {
+            float gray = ggml_tensor_get_f32(image_gray, ix, iy);
+            gray       = inverse ? 1.0f - gray : gray;
+            ggml_tensor_set_f32(image, gray, ix, iy);
+            ggml_tensor_set_f32(image, gray, ix, iy, 1);
+            ggml_tensor_set_f32(image, gray, ix, iy, 2);
+        }
+    }
+    free(img);
+    uint8_t* output = sd_tensor_to_image(image);
+    ggml_free(work_ctx);
+    return output;
+}
+
+#endif // __PREPROCESSING_HPP__
\ No newline at end of file
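For reference, a minimal standalone driver for the new preprocessor, mirroring the call made in examples/cli/main.cpp above. The wrapper name is hypothetical; note that preprocess_canny free()s its input buffer and returns a newly allocated RGB edge map of the same dimensions:

    #include "preprocessing.hpp"

    // Hypothetical helper: turn a caller-owned RGB8 buffer into a ControlNet
    // canny hint image using the default thresholds declared above.
    uint8_t* make_canny_hint(uint8_t* rgb, int width, int height) {
        // ownership of `rgb` passes to preprocess_canny, which frees it
        return preprocess_canny(rgb, width, height,
                                0.08f /* highThreshold */, 0.08f /* lowThreshold */,
                                0.8f /* weak */, 1.0f /* strong */, false /* inverse */);
    }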
From 429aebfe78a5e24c1f88bf798f568121e125f8bf Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Fri, 5 Jan 2024 09:41:27 -0500
Subject: [PATCH 19/29] fix ci error

---
 examples/cli/main.cpp | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index d9e74ed02..33144a6d4 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -88,18 +88,6 @@ struct SDParams {
     bool canny_preprocess = false;
 };

-static std::string sd_basename(const std::string& path) {
-    size_t pos = path.find_last_of('/');
-    if (pos != std::string::npos) {
-        return path.substr(pos + 1);
-    }
-    pos = path.find_last_of('\\');
-    if (pos != std::string::npos) {
-        return path.substr(pos + 1);
-    }
-    return path;
-}
-
 void print_params(SDParams params) {
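Patch 20 below moves ResBlock and SpatialTransformer out of unet.hpp into common.hpp, unchanged, so control.hpp can reuse them. One detail worth keeping in mind while reading the moved code: the transformer feed-forward implements GEGLU with a single fused projection ff_0_proj, whose two halves are split out as tensor views. In formula form,

    \mathrm{GEGLU}(x) = (xW + b) \odot \mathrm{GELU}(xV + c)

where (W, b) is the first half of ff_0_proj_w / ff_0_proj_b and (V, c) is the gate half.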
From 9bbee7e74c77722e500e003e91b002c099dce84f Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Fri, 5 Jan 2024 10:48:06 -0500
Subject: [PATCH 20/29] move resblock and spatial-T to common.hpp to reuse in
 control.hpp

---
 common.hpp  | 457 ++++++++++++++++++++++++++++++++++++++++++++++
 control.hpp |   8 +-
 unet.hpp    | 472 +---------------------------------------------------
 3 files changed, 467 insertions(+), 470 deletions(-)

diff --git a/common.hpp b/common.hpp
index 458f8be4a..37692d58d 100644
--- a/common.hpp
+++ b/common.hpp
@@ -83,4 +83,461 @@ struct UpSample {
     }
 };

+struct ResBlock {
+    // network hparams
+    int channels;      // model_channels * (1, 1, 1, 2, 2, 4, 4, 4)
+    int emb_channels;  // time_embed_dim
+    int out_channels;  // mult * model_channels
+
+    // network params
+    // in_layers
+    struct ggml_tensor* in_layer_0_w;  // [channels, ]
+    struct ggml_tensor* in_layer_0_b;  // [channels, ]
+    // in_layer_1 is nn.SILU()
+    struct ggml_tensor* in_layer_2_w;  // [out_channels, channels, 3, 3]
+    struct ggml_tensor* in_layer_2_b;  // [out_channels, ]
+
+    // emb_layers
+    // emb_layer_0 is nn.SILU()
+    struct ggml_tensor* emb_layer_1_w;  // [out_channels, emb_channels]
+    struct ggml_tensor* emb_layer_1_b;  // [out_channels, ]
+
+    // out_layers
+    struct ggml_tensor* out_layer_0_w;  // [out_channels, ]
+    struct ggml_tensor* out_layer_0_b;  // [out_channels, ]
+    // out_layer_1 is nn.SILU()
+    // out_layer_2 is nn.Dropout(), p = 0 for inference
+    struct ggml_tensor* out_layer_3_w;  // [out_channels, out_channels, 3, 3]
+    struct ggml_tensor* out_layer_3_b;  // [out_channels, ]
+
+    // skip connection, only if out_channels != channels
+    struct ggml_tensor* skip_w;  // [out_channels, channels, 1, 1]
+    struct ggml_tensor* skip_b;  // [out_channels, ]
+
+    size_t calculate_mem_size(ggml_type wtype) {
+        double mem_size = 0;
+        mem_size += 2 * channels * ggml_type_sizef(GGML_TYPE_F32);                         // in_layer_0_w/b
+        mem_size += out_channels * channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);      // in_layer_2_w
+        mem_size += 5 * out_channels * ggml_type_sizef(GGML_TYPE_F32);                     // in_layer_2_b/emb_layer_1_b/out_layer_0_w/out_layer_0_b/out_layer_3_b
+        mem_size += out_channels * emb_channels * ggml_type_sizef(wtype);                  // emb_layer_1_w
+        mem_size += out_channels * out_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);  // out_layer_3_w
+
+        if (out_channels != channels) {
+            mem_size += out_channels * channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16);  // skip_w
+            mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32);                     // skip_b
+        }
+        return static_cast<size_t>(mem_size);
+    }
+
+    void init_params(struct ggml_context* ctx, ggml_type wtype) {
+        in_layer_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels);
+        in_layer_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels);
+        in_layer_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, out_channels);
+        in_layer_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
+
+        emb_layer_1_w = ggml_new_tensor_2d(ctx, wtype, emb_channels, out_channels);
+        emb_layer_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
+
+        out_layer_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
+        out_layer_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
+        out_layer_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, out_channels);
+        out_layer_3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
+
+        if (out_channels != channels) {
+            skip_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, channels, out_channels);
+            skip_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
+        }
+    }
+
+    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
+        tensors[prefix + "in_layers.0.weight"] = in_layer_0_w;
+        tensors[prefix + "in_layers.0.bias"]   = in_layer_0_b;
+        tensors[prefix + "in_layers.2.weight"] = in_layer_2_w;
+        tensors[prefix + "in_layers.2.bias"]   = in_layer_2_b;
+
+        tensors[prefix + "emb_layers.1.weight"] = emb_layer_1_w;
+        tensors[prefix + "emb_layers.1.bias"]   = emb_layer_1_b;
+
+        tensors[prefix + "out_layers.0.weight"] = out_layer_0_w;
+        tensors[prefix + "out_layers.0.bias"]   = out_layer_0_b;
+        tensors[prefix + "out_layers.3.weight"] = out_layer_3_w;
+        tensors[prefix + "out_layers.3.bias"]   = out_layer_3_b;
+
+        if (out_channels != channels) {
+            tensors[prefix + "skip_connection.weight"] = skip_w;
+            tensors[prefix + "skip_connection.bias"]   = skip_b;
+        }
+    }
+
+    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb) {
+        // x: [N, channels, h, w]
+        // emb: [N, emb_channels]
+
+        // in_layers
+        auto h = ggml_nn_group_norm(ctx, x, in_layer_0_w, in_layer_0_b);
+        h      = ggml_silu_inplace(ctx, h);
+        h      = ggml_nn_conv_2d(ctx, h, in_layer_2_w, in_layer_2_b, 1, 1, 1, 1);  // [N, out_channels, h, w]
+
+        // emb_layers
+        auto emb_out = ggml_silu(ctx, emb);
+        emb_out      = ggml_nn_linear(ctx, emb_out, emb_layer_1_w, emb_layer_1_b);           // [N, out_channels]
+        emb_out      = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]);  // [N, out_channels, 1, 1]
+
+        // out_layers
+        h = ggml_add(ctx, h, emb_out);
+        h = ggml_nn_group_norm(ctx, h, out_layer_0_w, out_layer_0_b);
+        h = ggml_silu_inplace(ctx, h);
+
+        // dropout, skip for inference
+
+        h = ggml_nn_conv_2d(ctx, h, out_layer_3_w, out_layer_3_b, 1, 1, 1, 1);  // [N, out_channels, h, w]
+
+        // skip connection
+        if (out_channels != channels) {
+            x = ggml_nn_conv_2d(ctx, x, skip_w, skip_b);  // [N, out_channels, h, w]
+        }
+
+        h = ggml_add(ctx, h, x);
+        return h;  // [N, out_channels, h, w]
+    }
+};
+
+struct SpatialTransformer {
+    int in_channels;        // mult * model_channels
+    int n_head;             // num_heads
+    int d_head;             // in_channels // n_heads
+    int depth       = 1;    // 1
+    int context_dim = 768;  // hidden_size, 1024 for VERSION_2_x
+
+    // group norm
+    struct ggml_tensor* norm_w;  // [in_channels,]
+    struct ggml_tensor* norm_b;  // [in_channels,]
+
+    // proj_in
+    struct ggml_tensor* proj_in_w;  // [in_channels, in_channels, 1, 1]
+    struct ggml_tensor* proj_in_b;  // [in_channels,]
+
+    // transformer
+    struct Transformer {
+        // layer norm 1
+        struct ggml_tensor* norm1_w;  // [in_channels, ]
+        struct ggml_tensor* norm1_b;  // [in_channels, ]
+
+        // attn1
+        struct ggml_tensor* attn1_q_w;  // [in_channels, in_channels]
+        struct ggml_tensor* attn1_k_w;  // [in_channels, in_channels]
+        struct ggml_tensor* attn1_v_w;  // [in_channels, in_channels]
+
+        struct ggml_tensor* attn1_out_w;  // [in_channels, in_channels]
+        struct ggml_tensor* attn1_out_b;  // [in_channels, ]
+
+        // layer norm 2
+        struct ggml_tensor* norm2_w;  // [in_channels, ]
+        struct ggml_tensor* norm2_b;  // [in_channels, ]
+
+        // attn2
+        struct ggml_tensor* attn2_q_w;  // [in_channels, in_channels]
+        struct ggml_tensor* attn2_k_w;  // [in_channels, context_dim]
+        struct ggml_tensor* attn2_v_w;  // [in_channels, context_dim]
+
+        struct ggml_tensor* attn2_out_w;  // [in_channels, in_channels]
+        struct ggml_tensor* attn2_out_b;  // [in_channels, ]
+
+        // layer norm 3
+        struct ggml_tensor* norm3_w;  // [in_channels, ]
+        struct ggml_tensor* norm3_b;  // [in_channels, ]
+
+        // ff
+        struct ggml_tensor* ff_0_proj_w;  // [in_channels * 4 * 2, in_channels]
+        struct ggml_tensor* ff_0_proj_b;  // [in_channels * 4 * 2]
+
+        struct ggml_tensor* ff_2_w;  // [in_channels, in_channels * 4]
+        struct ggml_tensor* ff_2_b;  // [in_channels,]
+    };
+
+    std::vector<Transformer> transformers;
+
+    // proj_out
+    struct ggml_tensor* proj_out_w;  // [in_channels, in_channels, 1, 1]
+    struct ggml_tensor* proj_out_b;  // [in_channels,]
+
+    SpatialTransformer(int depth = 1)
+        : depth(depth) {
+        transformers.resize(depth);
+    }
+
+    int get_num_tensors() {
+        return depth * 20 + 7;
+    }
+
+    size_t calculate_mem_size(ggml_type wtype) {
+        double mem_size = 0;
+        mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32);                         // norm_w/norm_b
+        mem_size += 2 * in_channels * in_channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16);   // proj_in_w/proj_out_w
+        mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32);                         // proj_in_b/proj_out_b
+
+        // transformer
+        for (auto& transformer : transformers) {
+            mem_size += 6 * in_channels * ggml_type_sizef(GGML_TYPE_F32);            // norm1-3_w/b
+            mem_size += 6 * in_channels * in_channels * ggml_type_sizef(wtype);      // attn1_q/k/v/out_w attn2_q/out_w
+            mem_size += 2 * in_channels * context_dim * ggml_type_sizef(wtype);      // attn2_k/v_w
+            mem_size += in_channels * 4 * 2 * in_channels * ggml_type_sizef(wtype);  // ff_0_proj_w
+            mem_size += in_channels * 4 * 2 * ggml_type_sizef(GGML_TYPE_F32);        // ff_0_proj_b
+            mem_size += in_channels * 4 * in_channels * ggml_type_sizef(wtype);      // ff_2_w
+            mem_size += in_channels * ggml_type_sizef(GGML_TYPE_F32);                // ff_2_b
+        }
+        return static_cast<size_t>(mem_size);
+    }
+
+    void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) {
+        norm_w    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+        norm_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+        proj_in_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels);
+        proj_in_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+
+        proj_out_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels);
+        proj_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+
+        // transformer
+        for (auto& transformer : transformers) {
+            transformer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+            transformer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+
+            transformer.attn1_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
+            transformer.attn1_k_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
+            transformer.attn1_v_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
+
+            transformer.attn1_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
+            transformer.attn1_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+
+            transformer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+            transformer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+
+            transformer.attn2_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
+            transformer.attn2_k_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels);
+            transformer.attn2_v_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels);
+
+            transformer.attn2_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
+            transformer.attn2_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+
+            transformer.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+            transformer.norm3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+
+            transformer.ff_0_proj_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels * 4 * 2);
+            transformer.ff_0_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels * 4 * 2);
+
+            transformer.ff_2_w = ggml_new_tensor_2d(ctx, wtype, in_channels * 4, in_channels);
+            transformer.ff_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
+        }
+    }
+
+    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
+        tensors[prefix + "norm.weight"]    = norm_w;
+        tensors[prefix + "norm.bias"]      = norm_b;
+        tensors[prefix + "proj_in.weight"] = proj_in_w;
+        tensors[prefix + "proj_in.bias"]   = proj_in_b;
+
+        // transformer
+        for (int i = 0; i < transformers.size(); i++) {
+            auto& transformer              = transformers[i];
+            std::string transformer_prefix = prefix + "transformer_blocks."
+ std::to_string(i) + "."; + tensors[transformer_prefix + "attn1.to_q.weight"] = transformer.attn1_q_w; + tensors[transformer_prefix + "attn1.to_k.weight"] = transformer.attn1_k_w; + tensors[transformer_prefix + "attn1.to_v.weight"] = transformer.attn1_v_w; + + tensors[transformer_prefix + "attn1.to_out.0.weight"] = transformer.attn1_out_w; + tensors[transformer_prefix + "attn1.to_out.0.bias"] = transformer.attn1_out_b; + + tensors[transformer_prefix + "ff.net.0.proj.weight"] = transformer.ff_0_proj_w; + tensors[transformer_prefix + "ff.net.0.proj.bias"] = transformer.ff_0_proj_b; + tensors[transformer_prefix + "ff.net.2.weight"] = transformer.ff_2_w; + tensors[transformer_prefix + "ff.net.2.bias"] = transformer.ff_2_b; + + tensors[transformer_prefix + "attn2.to_q.weight"] = transformer.attn2_q_w; + tensors[transformer_prefix + "attn2.to_k.weight"] = transformer.attn2_k_w; + tensors[transformer_prefix + "attn2.to_v.weight"] = transformer.attn2_v_w; + + tensors[transformer_prefix + "attn2.to_out.0.weight"] = transformer.attn2_out_w; + tensors[transformer_prefix + "attn2.to_out.0.bias"] = transformer.attn2_out_b; + + tensors[transformer_prefix + "norm1.weight"] = transformer.norm1_w; + tensors[transformer_prefix + "norm1.bias"] = transformer.norm1_b; + tensors[transformer_prefix + "norm2.weight"] = transformer.norm2_w; + tensors[transformer_prefix + "norm2.bias"] = transformer.norm2_b; + tensors[transformer_prefix + "norm3.weight"] = transformer.norm3_w; + tensors[transformer_prefix + "norm3.bias"] = transformer.norm3_b; + } + + tensors[prefix + "proj_out.weight"] = proj_out_w; + tensors[prefix + "proj_out.bias"] = proj_out_b; + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { + // x: [N, in_channels, h, w] + // context: [N, max_position, hidden_size(aka context_dim)] + auto x_in = x; + x = ggml_nn_group_norm(ctx, x, norm_w, norm_b); + // proj_in + x = ggml_nn_conv_2d(ctx, x, proj_in_w, proj_in_b); // [N, in_channels, h, w] + + // transformer + const int64_t n = x->ne[3]; + const int64_t c = x->ne[2]; + const int64_t h = x->ne[1]; + const int64_t w = x->ne[0]; + const int64_t max_position = context->ne[1]; + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, in_channels] + + for (auto& transformer : transformers) { + auto r = x; + // layer norm 1 + x = ggml_reshape_2d(ctx, x, c, w * h * n); + x = ggml_nn_layer_norm(ctx, x, transformer.norm1_w, transformer.norm1_b); + + // self-attention + { + x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] + struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn1_q_w, x); // [N * h * w, in_channels] +#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) + q = ggml_scale_inplace(ctx, q, 1.0f / sqrt((float)d_head)); +#endif + q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] + q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] + q = ggml_reshape_3d(ctx, q, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] + + struct ggml_tensor* k = ggml_mul_mat(ctx, transformer.attn1_k_w, x); // [N * h * w, in_channels] + k = ggml_reshape_4d(ctx, k, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] + k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] + k = ggml_reshape_3d(ctx, k, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] + + struct ggml_tensor* v = ggml_mul_mat(ctx, transformer.attn1_v_w, x); // [N * h * 
w, in_channels] + v = ggml_reshape_4d(ctx, v, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] + v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, h * w] + v = ggml_reshape_3d(ctx, v, h * w, d_head, n_head * n); // [N * n_head, d_head, h * w] + +#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) + struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] +#else + struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, h * w] + // kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); + kq = ggml_soft_max_inplace(ctx, kq); + + struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, h * w, d_head] +#endif + kqv = ggml_reshape_4d(ctx, kqv, d_head, h * w, n_head, n); + kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, h * w, n_head, d_head] + + // x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); + x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); + + x = ggml_nn_linear(ctx, x, transformer.attn1_out_w, transformer.attn1_out_b); + + x = ggml_reshape_4d(ctx, x, c, w, h, n); + } + + x = ggml_add(ctx, x, r); + r = x; + + // layer norm 2 + x = ggml_nn_layer_norm(ctx, x, transformer.norm2_w, transformer.norm2_b); + + // cross-attention + { + x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] + context = ggml_reshape_2d(ctx, context, context->ne[0], context->ne[1] * context->ne[2]); // [N * max_position, hidden_size] + struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn2_q_w, x); // [N * h * w, in_channels] +#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) + q = ggml_scale_inplace(ctx, q, 1.0f / sqrt((float)d_head)); +#endif + q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] + q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] + q = ggml_reshape_3d(ctx, q, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] + + struct ggml_tensor* k = ggml_mul_mat(ctx, transformer.attn2_k_w, context); // [N * max_position, in_channels] + k = ggml_reshape_4d(ctx, k, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head] + k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, max_position, d_head] + k = ggml_reshape_3d(ctx, k, d_head, max_position, n_head * n); // [N * n_head, max_position, d_head] + + struct ggml_tensor* v = ggml_mul_mat(ctx, transformer.attn2_v_w, context); // [N * max_position, in_channels] + v = ggml_reshape_4d(ctx, v, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head] + v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, max_position] + v = ggml_reshape_3d(ctx, v, max_position, d_head, n_head * n); // [N * n_head, d_head, max_position] +#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) + struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] +#else + struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position] + // kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); + kq = ggml_soft_max_inplace(ctx, kq); + + struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, h * w, d_head] +#endif + kqv = ggml_reshape_4d(ctx, kqv, d_head, h * w, n_head, n); + kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); + + // x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * 
n_head, h * w * n));  // [N * h * w, in_channels]
+                x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n);  // [N * h * w, in_channels]
+
+                x = ggml_nn_linear(ctx, x, transformer.attn2_out_w, transformer.attn2_out_b);
+
+                x = ggml_reshape_4d(ctx, x, c, w, h, n);
+            }
+
+            x = ggml_add(ctx, x, r);
+            r = x;
+
+            // layer norm 3
+            x = ggml_reshape_2d(ctx, x, c, h * w * n);  // [N * h * w, in_channels]
+            x = ggml_nn_layer_norm(ctx, x, transformer.norm3_w, transformer.norm3_b);
+
+            // ff
+            {
+                // GEGLU
+                auto x_w    = ggml_view_2d(ctx,
+                                           transformer.ff_0_proj_w,
+                                           transformer.ff_0_proj_w->ne[0],
+                                           transformer.ff_0_proj_w->ne[1] / 2,
+                                           transformer.ff_0_proj_w->nb[1],
+                                           0);  // [in_channels * 4, in_channels]
+                auto x_b    = ggml_view_1d(ctx,
+                                           transformer.ff_0_proj_b,
+                                           transformer.ff_0_proj_b->ne[0] / 2,
+                                           0);  // [in_channels * 4, ]
+                auto gate_w = ggml_view_2d(ctx,
+                                           transformer.ff_0_proj_w,
+                                           transformer.ff_0_proj_w->ne[0],
+                                           transformer.ff_0_proj_w->ne[1] / 2,
+                                           transformer.ff_0_proj_w->nb[1],
+                                           transformer.ff_0_proj_w->nb[1] * transformer.ff_0_proj_w->ne[1] / 2);  // [in_channels * 4, in_channels]
+                auto gate_b = ggml_view_1d(ctx,
+                                           transformer.ff_0_proj_b,
+                                           transformer.ff_0_proj_b->ne[0] / 2,
+                                           transformer.ff_0_proj_b->nb[0] * transformer.ff_0_proj_b->ne[0] / 2);  // [in_channels * 4, ]
+                x         = ggml_reshape_2d(ctx, x, c, w * h * n);
+                auto x_in = x;
+                x         = ggml_nn_linear(ctx, x_in, x_w, x_b);        // [N * h * w, in_channels * 4]
+                auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b);  // [N * h * w, in_channels * 4]
+
+                gate = ggml_gelu_inplace(ctx, gate);
+
+                x = ggml_mul(ctx, x, gate);  // [N * h * w, in_channels * 4]
+                // fc
+                x = ggml_nn_linear(ctx, x, transformer.ff_2_w, transformer.ff_2_b);  // [N * h * w, in_channels]
+            }
+
+            x = ggml_reshape_4d(ctx, x, c, w, h, n);  // [N, h, w, in_channels]
+
+            // residual
+            x = ggml_add(ctx, x, r);
+        }
+
+        x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3));  // [N, in_channels, h, w]
+
+        // proj_out
+        x = ggml_nn_conv_2d(ctx, x, proj_out_w, proj_out_b);  // [N, in_channels, h, w]
+
+        x = ggml_add(ctx, x, x_in);
+        return x;
+    }
+};
+
 #endif // __COMMON_HPP__
\ No newline at end of file
diff --git a/control.hpp b/control.hpp
index cf8d88be5..aa62bb10d 100644
--- a/control.hpp
+++ b/control.hpp
@@ -2,9 +2,11 @@
 #define __CONTROL_HPP__

 #include "ggml_extend.hpp"
-#include "unet.hpp"
+#include "common.hpp"
 #include "model.h"

+#define CONTROL_NET_GRAPH_SIZE 1536
+
 /*
     =================================== ControlNet ===================================
     Reference: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/cldm/cldm.py

@@ -531,7 +533,7 @@ struct ControlNet : public GGMLModule {
                                     struct ggml_tensor* context,
                                     struct ggml_tensor* t_emb = NULL) {
         // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
-        static size_t buf_size = ggml_tensor_overhead() * UNET_GRAPH_SIZE + ggml_graph_overhead();
+        static size_t buf_size = ggml_tensor_overhead() * CONTROL_NET_GRAPH_SIZE + ggml_graph_overhead();
         static std::vector<char> buf(buf_size);

         struct ggml_init_params params = {
@@ -543,7 +545,7 @@ struct ControlNet : public GGMLModule {

         struct ggml_context* ctx0 = ggml_init(params);

-        struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, UNET_GRAPH_SIZE, false);
+        struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, CONTROL_NET_GRAPH_SIZE, false);

         // temporary tensors used to transfer data from cpu to gpu if needed
         struct ggml_tensor* x_t = NULL;
diff --git a/unet.hpp b/unet.hpp
index 
de3621912..c4ece637d 100644 --- a/unet.hpp +++ b/unet.hpp @@ -3,468 +3,12 @@ #include "common.hpp" #include "ggml_extend.hpp" +#include "model.h" /*==================================================== UnetModel =====================================================*/ #define UNET_GRAPH_SIZE 10240 -struct ResBlock { - // network hparams - int channels; // model_channels * (1, 1, 1, 2, 2, 4, 4, 4) - int emb_channels; // time_embed_dim - int out_channels; // mult * model_channels - - // network params - // in_layers - struct ggml_tensor* in_layer_0_w; // [channels, ] - struct ggml_tensor* in_layer_0_b; // [channels, ] - // in_layer_1 is nn.SILU() - struct ggml_tensor* in_layer_2_w; // [out_channels, channels, 3, 3] - struct ggml_tensor* in_layer_2_b; // [out_channels, ] - - // emb_layers - // emb_layer_0 is nn.SILU() - struct ggml_tensor* emb_layer_1_w; // [out_channels, emb_channels] - struct ggml_tensor* emb_layer_1_b; // [out_channels, ] - - // out_layers - struct ggml_tensor* out_layer_0_w; // [out_channels, ] - struct ggml_tensor* out_layer_0_b; // [out_channels, ] - // out_layer_1 is nn.SILU() - // out_layer_2 is nn.Dropout(), p = 0 for inference - struct ggml_tensor* out_layer_3_w; // [out_channels, out_channels, 3, 3] - struct ggml_tensor* out_layer_3_b; // [out_channels, ] - - // skip connection, only if out_channels != channels - struct ggml_tensor* skip_w; // [out_channels, channels, 1, 1] - struct ggml_tensor* skip_b; // [out_channels, ] - - size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += 2 * channels * ggml_type_sizef(GGML_TYPE_F32); // in_layer_0_w/b - mem_size += out_channels * channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // in_layer_2_w - mem_size += 5 * out_channels * ggml_type_sizef(GGML_TYPE_F32); // in_layer_2_b/emb_layer_1_b/out_layer_0_w/out_layer_0_b/out_layer_3_b - mem_size += out_channels * emb_channels * ggml_type_sizef(wtype); // emb_layer_1_w - mem_size += out_channels * out_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // out_layer_3_w - - if (out_channels != channels) { - mem_size += out_channels * channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16); // skip_w - mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32); // skip_b - } - return static_cast(mem_size); - } - - void init_params(struct ggml_context* ctx, ggml_type wtype) { - in_layer_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels); - in_layer_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels); - in_layer_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, out_channels); - in_layer_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - - emb_layer_1_w = ggml_new_tensor_2d(ctx, wtype, emb_channels, out_channels); - emb_layer_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - - out_layer_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - out_layer_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - out_layer_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, out_channels); - out_layer_3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - - if (out_channels != channels) { - skip_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, channels, out_channels); - skip_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - } - } - - void map_by_name(std::map& tensors, const std::string prefix) { - tensors[prefix + "in_layers.0.weight"] = in_layer_0_w; - tensors[prefix + "in_layers.0.bias"] = in_layer_0_b; - tensors[prefix + "in_layers.2.weight"] = in_layer_2_w; - tensors[prefix + 
"in_layers.2.bias"] = in_layer_2_b; - - tensors[prefix + "emb_layers.1.weight"] = emb_layer_1_w; - tensors[prefix + "emb_layers.1.bias"] = emb_layer_1_b; - - tensors[prefix + "out_layers.0.weight"] = out_layer_0_w; - tensors[prefix + "out_layers.0.bias"] = out_layer_0_b; - tensors[prefix + "out_layers.3.weight"] = out_layer_3_w; - tensors[prefix + "out_layers.3.bias"] = out_layer_3_b; - - if (out_channels != channels) { - tensors[prefix + "skip_connection.weight"] = skip_w; - tensors[prefix + "skip_connection.bias"] = skip_b; - } - } - - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb) { - // x: [N, channels, h, w] - // emb: [N, emb_channels] - - // in_layers - auto h = ggml_nn_group_norm(ctx, x, in_layer_0_w, in_layer_0_b); - h = ggml_silu_inplace(ctx, h); - h = ggml_nn_conv_2d(ctx, h, in_layer_2_w, in_layer_2_b, 1, 1, 1, 1); // [N, out_channels, h, w] - - // emb_layers - auto emb_out = ggml_silu(ctx, emb); - emb_out = ggml_nn_linear(ctx, emb_out, emb_layer_1_w, emb_layer_1_b); // [N, out_channels] - emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1] - - // out_layers - h = ggml_add(ctx, h, emb_out); - h = ggml_nn_group_norm(ctx, h, out_layer_0_w, out_layer_0_b); - h = ggml_silu_inplace(ctx, h); - - // dropout, skip for inference - - h = ggml_nn_conv_2d(ctx, h, out_layer_3_w, out_layer_3_b, 1, 1, 1, 1); // [N, out_channels, h, w] - - // skip connection - if (out_channels != channels) { - x = ggml_nn_conv_2d(ctx, x, skip_w, skip_b); // [N, out_channels, h, w] - } - - h = ggml_add(ctx, h, x); - return h; // [N, out_channels, h, w] - } -}; - -struct SpatialTransformer { - int in_channels; // mult * model_channels - int n_head; // num_heads - int d_head; // in_channels // n_heads - int depth = 1; // 1 - int context_dim = 768; // hidden_size, 1024 for VERSION_2_x - - // group norm - struct ggml_tensor* norm_w; // [in_channels,] - struct ggml_tensor* norm_b; // [in_channels,] - - // proj_in - struct ggml_tensor* proj_in_w; // [in_channels, in_channels, 1, 1] - struct ggml_tensor* proj_in_b; // [in_channels,] - - // transformer - struct Transformer { - // layer norm 1 - struct ggml_tensor* norm1_w; // [in_channels, ] - struct ggml_tensor* norm1_b; // [in_channels, ] - - // attn1 - struct ggml_tensor* attn1_q_w; // [in_channels, in_channels] - struct ggml_tensor* attn1_k_w; // [in_channels, in_channels] - struct ggml_tensor* attn1_v_w; // [in_channels, in_channels] - - struct ggml_tensor* attn1_out_w; // [in_channels, in_channels] - struct ggml_tensor* attn1_out_b; // [in_channels, ] - - // layer norm 2 - struct ggml_tensor* norm2_w; // [in_channels, ] - struct ggml_tensor* norm2_b; // [in_channels, ] - - // attn2 - struct ggml_tensor* attn2_q_w; // [in_channels, in_channels] - struct ggml_tensor* attn2_k_w; // [in_channels, context_dim] - struct ggml_tensor* attn2_v_w; // [in_channels, context_dim] - - struct ggml_tensor* attn2_out_w; // [in_channels, in_channels] - struct ggml_tensor* attn2_out_b; // [in_channels, ] - - // layer norm 3 - struct ggml_tensor* norm3_w; // [in_channels, ] - struct ggml_tensor* norm3_b; // [in_channels, ] - - // ff - struct ggml_tensor* ff_0_proj_w; // [in_channels * 4 * 2, in_channels] - struct ggml_tensor* ff_0_proj_b; // [in_channels * 4 * 2] - - struct ggml_tensor* ff_2_w; // [in_channels, in_channels * 4] - struct ggml_tensor* ff_2_b; // [in_channels,] - }; - - std::vector transformers; - - // proj_out - struct ggml_tensor* proj_out_w; // [in_channels, 
in_channels, 1, 1] - struct ggml_tensor* proj_out_b; // [in_channels,] - - SpatialTransformer(int depth = 1) - : depth(depth) { - transformers.resize(depth); - } - - int get_num_tensors() { - return depth * 20 + 7; - } - - size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // norm_w/norm_b - mem_size += 2 * in_channels * in_channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16); // proj_in_w/proj_out_w - mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // proj_in_b/proj_out_b - - // transformer - for (auto& transformer : transformers) { - mem_size += 6 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // norm1-3_w/b - mem_size += 6 * in_channels * in_channels * ggml_type_sizef(wtype); // attn1_q/k/v/out_w attn2_q/out_w - mem_size += 2 * in_channels * context_dim * ggml_type_sizef(wtype); // attn2_k/v_w - mem_size += in_channels * 4 * 2 * in_channels * ggml_type_sizef(wtype); // ff_0_proj_w - mem_size += in_channels * 4 * 2 * ggml_type_sizef(GGML_TYPE_F32); // ff_0_proj_b - mem_size += in_channels * 4 * in_channels * ggml_type_sizef(wtype); // ff_2_w - mem_size += in_channels * ggml_type_sizef(GGML_TYPE_F32); // ff_2_b - } - return static_cast(mem_size); - } - - void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) { - norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - proj_in_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels); - proj_in_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - proj_out_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels); - proj_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - // transformer - for (auto& transformer : transformers) { - transformer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - transformer.attn1_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn1_k_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn1_v_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - - transformer.attn1_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn1_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - transformer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - transformer.attn2_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn2_k_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels); - transformer.attn2_v_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels); - - transformer.attn2_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn2_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - transformer.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - transformer.ff_0_proj_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels * 4 * 2); - transformer.ff_0_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels * 4 * 2); - - transformer.ff_2_w = ggml_new_tensor_2d(ctx, wtype, in_channels * 4, in_channels); - transformer.ff_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - } - } - - void map_by_name(std::map& 
tensors, const std::string prefix) { - tensors[prefix + "norm.weight"] = norm_w; - tensors[prefix + "norm.bias"] = norm_b; - tensors[prefix + "proj_in.weight"] = proj_in_w; - tensors[prefix + "proj_in.bias"] = proj_in_b; - - // transformer - for (int i = 0; i < transformers.size(); i++) { - auto& transformer = transformers[i]; - std::string transformer_prefix = prefix + "transformer_blocks." + std::to_string(i) + "."; - tensors[transformer_prefix + "attn1.to_q.weight"] = transformer.attn1_q_w; - tensors[transformer_prefix + "attn1.to_k.weight"] = transformer.attn1_k_w; - tensors[transformer_prefix + "attn1.to_v.weight"] = transformer.attn1_v_w; - - tensors[transformer_prefix + "attn1.to_out.0.weight"] = transformer.attn1_out_w; - tensors[transformer_prefix + "attn1.to_out.0.bias"] = transformer.attn1_out_b; - - tensors[transformer_prefix + "ff.net.0.proj.weight"] = transformer.ff_0_proj_w; - tensors[transformer_prefix + "ff.net.0.proj.bias"] = transformer.ff_0_proj_b; - tensors[transformer_prefix + "ff.net.2.weight"] = transformer.ff_2_w; - tensors[transformer_prefix + "ff.net.2.bias"] = transformer.ff_2_b; - - tensors[transformer_prefix + "attn2.to_q.weight"] = transformer.attn2_q_w; - tensors[transformer_prefix + "attn2.to_k.weight"] = transformer.attn2_k_w; - tensors[transformer_prefix + "attn2.to_v.weight"] = transformer.attn2_v_w; - - tensors[transformer_prefix + "attn2.to_out.0.weight"] = transformer.attn2_out_w; - tensors[transformer_prefix + "attn2.to_out.0.bias"] = transformer.attn2_out_b; - - tensors[transformer_prefix + "norm1.weight"] = transformer.norm1_w; - tensors[transformer_prefix + "norm1.bias"] = transformer.norm1_b; - tensors[transformer_prefix + "norm2.weight"] = transformer.norm2_w; - tensors[transformer_prefix + "norm2.bias"] = transformer.norm2_b; - tensors[transformer_prefix + "norm3.weight"] = transformer.norm3_w; - tensors[transformer_prefix + "norm3.bias"] = transformer.norm3_b; - } - - tensors[prefix + "proj_out.weight"] = proj_out_w; - tensors[prefix + "proj_out.bias"] = proj_out_b; - } - - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { - // x: [N, in_channels, h, w] - // context: [N, max_position, hidden_size(aka context_dim)] - auto x_in = x; - x = ggml_nn_group_norm(ctx, x, norm_w, norm_b); - // proj_in - x = ggml_nn_conv_2d(ctx, x, proj_in_w, proj_in_b); // [N, in_channels, h, w] - - // transformer - const int64_t n = x->ne[3]; - const int64_t c = x->ne[2]; - const int64_t h = x->ne[1]; - const int64_t w = x->ne[0]; - const int64_t max_position = context->ne[1]; - x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, in_channels] - - for (auto& transformer : transformers) { - auto r = x; - // layer norm 1 - x = ggml_reshape_2d(ctx, x, c, w * h * n); - x = ggml_nn_layer_norm(ctx, x, transformer.norm1_w, transformer.norm1_b); - - // self-attention - { - x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] - struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn1_q_w, x); // [N * h * w, in_channels] -#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) - q = ggml_scale_inplace(ctx, q, 1.0f / sqrt((float)d_head)); -#endif - q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] - q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] - q = ggml_reshape_3d(ctx, q, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] - - struct ggml_tensor* k = ggml_mul_mat(ctx, 
transformer.attn1_k_w, x); // [N * h * w, in_channels] - k = ggml_reshape_4d(ctx, k, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] - k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] - k = ggml_reshape_3d(ctx, k, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] - - struct ggml_tensor* v = ggml_mul_mat(ctx, transformer.attn1_v_w, x); // [N * h * w, in_channels] - v = ggml_reshape_4d(ctx, v, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] - v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, h * w] - v = ggml_reshape_3d(ctx, v, h * w, d_head, n_head * n); // [N * n_head, d_head, h * w] - -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) - struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] -#else - struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, h * w] - // kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); - kq = ggml_soft_max_inplace(ctx, kq); - - struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, h * w, d_head] -#endif - kqv = ggml_reshape_4d(ctx, kqv, d_head, h * w, n_head, n); - kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, h * w, n_head, d_head] - - // x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); - x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); - - x = ggml_nn_linear(ctx, x, transformer.attn1_out_w, transformer.attn1_out_b); - - x = ggml_reshape_4d(ctx, x, c, w, h, n); - } - - x = ggml_add(ctx, x, r); - r = x; - - // layer norm 2 - x = ggml_nn_layer_norm(ctx, x, transformer.norm2_w, transformer.norm2_b); - - // cross-attention - { - x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] - context = ggml_reshape_2d(ctx, context, context->ne[0], context->ne[1] * context->ne[2]); // [N * max_position, hidden_size] - struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn2_q_w, x); // [N * h * w, in_channels] -#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) - q = ggml_scale_inplace(ctx, q, 1.0f / sqrt((float)d_head)); -#endif - q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] - q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] - q = ggml_reshape_3d(ctx, q, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] - - struct ggml_tensor* k = ggml_mul_mat(ctx, transformer.attn2_k_w, context); // [N * max_position, in_channels] - k = ggml_reshape_4d(ctx, k, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head] - k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, max_position, d_head] - k = ggml_reshape_3d(ctx, k, d_head, max_position, n_head * n); // [N * n_head, max_position, d_head] - - struct ggml_tensor* v = ggml_mul_mat(ctx, transformer.attn2_v_w, context); // [N * max_position, in_channels] - v = ggml_reshape_4d(ctx, v, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head] - v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, max_position] - v = ggml_reshape_3d(ctx, v, max_position, d_head, n_head * n); // [N * n_head, d_head, max_position] -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) - struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] -#else - struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * 
w, max_position] - // kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); - kq = ggml_soft_max_inplace(ctx, kq); - - struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, h * w, d_head] -#endif - kqv = ggml_reshape_4d(ctx, kqv, d_head, h * w, n_head, n); - kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); - - // x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); // [N * h * w, in_channels] - x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); // [N * h * w, in_channels] - - x = ggml_nn_linear(ctx, x, transformer.attn2_out_w, transformer.attn2_out_b); - - x = ggml_reshape_4d(ctx, x, c, w, h, n); - } - - x = ggml_add(ctx, x, r); - r = x; - - // layer norm 3 - x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] - x = ggml_nn_layer_norm(ctx, x, transformer.norm3_w, transformer.norm3_b); - - // ff - { - // GEGLU - auto x_w = ggml_view_2d(ctx, - transformer.ff_0_proj_w, - transformer.ff_0_proj_w->ne[0], - transformer.ff_0_proj_w->ne[1] / 2, - transformer.ff_0_proj_w->nb[1], - 0); // [in_channels * 4, in_channels] - auto x_b = ggml_view_1d(ctx, - transformer.ff_0_proj_b, - transformer.ff_0_proj_b->ne[0] / 2, - 0); // [in_channels * 4, in_channels] - auto gate_w = ggml_view_2d(ctx, - transformer.ff_0_proj_w, - transformer.ff_0_proj_w->ne[0], - transformer.ff_0_proj_w->ne[1] / 2, - transformer.ff_0_proj_w->nb[1], - transformer.ff_0_proj_w->nb[1] * transformer.ff_0_proj_w->ne[1] / 2); // [in_channels * 4, ] - auto gate_b = ggml_view_1d(ctx, - transformer.ff_0_proj_b, - transformer.ff_0_proj_b->ne[0] / 2, - transformer.ff_0_proj_b->nb[0] * transformer.ff_0_proj_b->ne[0] / 2); // [in_channels * 4, ] - x = ggml_reshape_2d(ctx, x, c, w * h * n); - auto x_in = x; - x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [N * h * w, in_channels * 4] - auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [N * h * w, in_channels * 4] - - gate = ggml_gelu_inplace(ctx, gate); - - x = ggml_mul(ctx, x, gate); // [N * h * w, in_channels * 4] - // fc - x = ggml_nn_linear(ctx, x, transformer.ff_2_w, transformer.ff_2_b); // [N * h * w, in_channels] - } - - x = ggml_reshape_4d(ctx, x, c, w, h, n); // [N, h, w, in_channels] - - // residual - x = ggml_add(ctx, x, r); - } - - x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, in_channels, h, w] - - // proj_out - x = ggml_nn_conv_2d(ctx, x, proj_out_w, proj_out_b); // [N, in_channels, h, w] - - x = ggml_add(ctx, x, x_in); - return x; - } -}; - // ldm.modules.diffusionmodules.openaimodel.UNetModel struct UNetModel : public GGMLModule { SDVersion version = VERSION_1_x; @@ -893,7 +437,7 @@ struct UNetModel : public GGMLModule { struct ggml_tensor* timesteps, struct ggml_tensor* context, std::vector control, - struct ggml_tensor* control_strength, + float control_net_strength, struct ggml_tensor* t_emb = NULL, struct ggml_tensor* y = NULL) { // x: [N, in_channels, h, w] @@ -952,7 +496,7 @@ struct UNetModel : public GGMLModule { h = middle_block_2.forward(ctx0, h, emb); // [N, 4*model_channels, h/8, w/8] if(control.size() > 0) { - auto cs = ggml_scale_inplace(ctx0, control[control.size() - 1], control_strength); + auto cs = ggml_scale_inplace(ctx0, control[control.size() - 1], control_net_strength); h = ggml_add(ctx0, h, cs); // middle control } @@ -964,7 +508,7 @@ struct UNetModel : public GGMLModule { hs.pop_back(); if(control.size() > 0) { - auto cs = ggml_scale_inplace(ctx0, control[control_offset], control_strength); + auto cs = ggml_scale_inplace(ctx0, control[control_offset], 
control_net_strength); h_skip = ggml_add(ctx0, h_skip, cs); // control net condition control_offset--; } @@ -1022,7 +566,6 @@ struct UNetModel : public GGMLModule { struct ggml_tensor* context_t = NULL; struct ggml_tensor* t_emb_t = NULL; struct ggml_tensor* y_t = NULL; - struct ggml_tensor* control_strength = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); std::vector control_t; // it's performing a compute, check if backend isn't cpu @@ -1082,12 +625,7 @@ struct UNetModel : public GGMLModule { control_t = control; } - ggml_allocr_alloc(compute_allocr, control_strength); - if(!ggml_allocr_is_measure(compute_allocr)) { - ggml_backend_tensor_set_and_sync(backend, control_strength, &control_net_strength, 0, sizeof(float)); - } - - struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, control_t, control_strength, t_emb_t, y_t); + struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, control_t, control_net_strength, t_emb_t, y_t); ggml_build_forward_expand(gf, out); ggml_free(ctx0); From cd08dad5b77ae83bb4a2d9ae4d18da72e33d87d5 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 5 Jan 2024 11:19:45 -0500 Subject: [PATCH 21/29] refactor: change ggml_type_sizef to ggml_row_size --- clip.hpp | 30 +++++++++++++-------------- common.hpp | 58 ++++++++++++++++++++++++++--------------------------- control.hpp | 26 ++++++++++++------------ unet.hpp | 30 +++++++++++++-------------- vae.hpp | 52 +++++++++++++++++++++++------------------------ 5 files changed, 98 insertions(+), 98 deletions(-) diff --git a/clip.hpp b/clip.hpp index 4f1f9534b..b98f6fd7a 100644 --- a/clip.hpp +++ b/clip.hpp @@ -209,13 +209,13 @@ struct ResidualAttentionBlock { struct ggml_tensor* ln2_b; // [hidden_size, ] size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += 4 * hidden_size * hidden_size * ggml_type_sizef(wtype); // q_w/k_w/v_w/out_w - mem_size += 8 * hidden_size * ggml_type_sizef(GGML_TYPE_F32); // q_b/k_b/v_b/out_b/ln1_w/ln1_b/ln2_w/ln2_b - mem_size += 2 * hidden_size * intermediate_size * ggml_type_sizef(wtype); // fc1_w/fc2_w - mem_size += intermediate_size * ggml_type_sizef(GGML_TYPE_F32); // fc1_b - mem_size += hidden_size * ggml_type_sizef(GGML_TYPE_F32); // fc2_b - return static_cast(mem_size); + size_t mem_size = 0; + mem_size += 4 * ggml_row_size(wtype, hidden_size * hidden_size); // q_w/k_w/v_w/out_w + mem_size += 8 * ggml_row_size(GGML_TYPE_F32, hidden_size); // q_b/k_b/v_b/out_b/ln1_w/ln1_b/ln2_w/ln2_b + mem_size += 2 * ggml_row_size(wtype, hidden_size * intermediate_size); // fc1_w/fc2_w + mem_size += ggml_row_size(GGML_TYPE_F32, intermediate_size); // fc1_b + mem_size += ggml_row_size(GGML_TYPE_F32, hidden_size); // fc2_b + return mem_size; } void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) { @@ -411,21 +411,21 @@ struct CLIPTextModel { } size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(GGML_TYPE_I32); // position_ids - mem_size += hidden_size * vocab_size * ggml_type_sizef(wtype); // token_embed_weight - mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(wtype); // position_embed_weight + size_t mem_size = 0; + mem_size += ggml_row_size(GGML_TYPE_I32, hidden_size * max_position_embeddings); // position_ids + mem_size += ggml_row_size(wtype, hidden_size * vocab_size); // token_embed_weight + mem_size += ggml_row_size(wtype, hidden_size * max_position_embeddings); // position_embed_weight if(version == OPENAI_CLIP_VIT_L_14) 
{ - mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(wtype); // token_embed_custom + mem_size += ggml_row_size(wtype, hidden_size * max_position_embeddings); // token_embed_custom } for (int i = 0; i < num_hidden_layers; i++) { mem_size += resblocks[i].calculate_mem_size(wtype); } - mem_size += 2 * hidden_size * ggml_type_sizef(GGML_TYPE_F32); // final_ln_w/b + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, hidden_size); // final_ln_w/b if (version == OPEN_CLIP_VIT_BIGG_14) { - mem_size += hidden_size * projection_dim * ggml_type_sizef(GGML_TYPE_F32); // text_projection + mem_size += ggml_row_size(GGML_TYPE_F32, hidden_size * projection_dim); // text_projection } - return static_cast(mem_size); + return mem_size; } void map_by_name(std::map& tensors, const std::string prefix) { diff --git a/common.hpp b/common.hpp index 37692d58d..a71e4d370 100644 --- a/common.hpp +++ b/common.hpp @@ -15,10 +15,10 @@ struct DownSample { bool vae_downsample = false; size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += out_channels * channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // op_w - mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32); // op_b - return static_cast(mem_size); + size_t mem_size = 0; + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 3 * 3); // op_w + mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // op_b + return mem_size; } void init_params(struct ggml_context* ctx, ggml_type wtype) { @@ -59,10 +59,10 @@ struct UpSample { struct ggml_tensor* conv_b; // [out_channels,] size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += out_channels * channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // op_w - mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32); // op_b - return static_cast(mem_size); + size_t mem_size = 0; + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 3 * 3); // op_w + mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // op_b + return mem_size; } void init_params(struct ggml_context* ctx, ggml_type wtype) { @@ -115,18 +115,18 @@ struct ResBlock { struct ggml_tensor* skip_b; // [out_channels, ] size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += 2 * channels * ggml_type_sizef(GGML_TYPE_F32); // in_layer_0_w/b - mem_size += out_channels * channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // in_layer_2_w - mem_size += 5 * out_channels * ggml_type_sizef(GGML_TYPE_F32); // in_layer_2_b/emb_layer_1_b/out_layer_0_w/out_layer_0_b/out_layer_3_b - mem_size += out_channels * emb_channels * ggml_type_sizef(wtype); // emb_layer_1_w - mem_size += out_channels * out_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // out_layer_3_w + size_t mem_size = 0; + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, channels); // in_layer_0_w/b + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 3 * 3); // in_layer_2_w + mem_size += 5 * ggml_row_size(GGML_TYPE_F32, out_channels); // in_layer_2_b/emb_layer_1_b/out_layer_0_w/out_layer_0_b/out_layer_3_b + mem_size += ggml_row_size(wtype, out_channels * emb_channels); // emb_layer_1_w + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * out_channels * 3 * 3); // out_layer_3_w if (out_channels != channels) { - mem_size += out_channels * channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16); // skip_w - mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32); // skip_b + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 1 * 1); // skip_w + mem_size += 
ggml_row_size(GGML_TYPE_F32, out_channels); // skip_b } - return static_cast(mem_size); + return mem_size; } void init_params(struct ggml_context* ctx, ggml_type wtype) { @@ -271,22 +271,22 @@ struct SpatialTransformer { } size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // norm_w/norm_b - mem_size += 2 * in_channels * in_channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16); // proj_in_w/proj_out_w - mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // proj_in_b/proj_out_b + size_t mem_size = 0; + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, in_channels); // norm_w/norm_b + mem_size += 2 * ggml_row_size(GGML_TYPE_F16, in_channels * in_channels * 1 * 1); // proj_in_w/proj_out_w + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, in_channels); // proj_in_b/proj_out_b // transformer for (auto& transformer : transformers) { - mem_size += 6 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // norm1-3_w/b - mem_size += 6 * in_channels * in_channels * ggml_type_sizef(wtype); // attn1_q/k/v/out_w attn2_q/out_w - mem_size += 2 * in_channels * context_dim * ggml_type_sizef(wtype); // attn2_k/v_w - mem_size += in_channels * 4 * 2 * in_channels * ggml_type_sizef(wtype); // ff_0_proj_w - mem_size += in_channels * 4 * 2 * ggml_type_sizef(GGML_TYPE_F32); // ff_0_proj_b - mem_size += in_channels * 4 * in_channels * ggml_type_sizef(wtype); // ff_2_w - mem_size += in_channels * ggml_type_sizef(GGML_TYPE_F32); // ff_2_b + mem_size += 6 * ggml_row_size(GGML_TYPE_F32, in_channels); // norm1-3_w/b + mem_size += 6 * ggml_row_size(wtype, in_channels * in_channels); // attn1_q/k/v/out_w attn2_q/out_w + mem_size += 2 * ggml_row_size(wtype, in_channels * context_dim); // attn2_k/v_w + mem_size += ggml_row_size(wtype, in_channels * 4 * 2 * in_channels ); // ff_0_proj_w + mem_size += ggml_row_size(GGML_TYPE_F32, in_channels * 4 * 2); // ff_0_proj_b + mem_size += ggml_row_size(wtype, in_channels * 4 * in_channels); // ff_2_w + mem_size += ggml_row_size(GGML_TYPE_F32, in_channels); // ff_2_b } - return static_cast(mem_size); + return mem_size; } void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) { diff --git a/control.hpp b/control.hpp index aa62bb10d..eddfd1470 100644 --- a/control.hpp +++ b/control.hpp @@ -44,7 +44,7 @@ struct CNHintBlock { } mem_size += model_channels * feat_channels[3] * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_final_w mem_size += model_channels * ggml_type_size(GGML_TYPE_F32); // conv_final_b - return static_cast(mem_size); + return mem_size; } void init_params(struct ggml_context* ctx) { @@ -223,15 +223,15 @@ struct ControlNet : public GGMLModule { } size_t calculate_mem_size() { - double mem_size = 0; + size_t mem_size = 0; mem_size += input_hint_block.calculate_mem_size(); - mem_size += time_embed_dim * model_channels * ggml_type_sizef(wtype); // time_embed_0_w - mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // time_embed_0_b - mem_size += time_embed_dim * time_embed_dim * ggml_type_sizef(wtype); // time_embed_2_w - mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // time_embed_2_b + mem_size += ggml_row_size(wtype, time_embed_dim * model_channels); // time_embed_0_w + mem_size += ggml_row_size(GGML_TYPE_F32, time_embed_dim); // time_embed_0_b + mem_size += ggml_row_size(wtype, time_embed_dim * time_embed_dim); // time_embed_2_w + mem_size += ggml_row_size(GGML_TYPE_F32,time_embed_dim); // time_embed_2_b - mem_size += model_channels * in_channels * 3 
* 3 * ggml_type_sizef(GGML_TYPE_F16); // input_block_0_w - mem_size += model_channels * ggml_type_sizef(GGML_TYPE_F32); // input_block_0_b + mem_size += ggml_row_size(GGML_TYPE_F16, model_channels * in_channels * 3 * 3); // input_block_0_w + mem_size += ggml_row_size(GGML_TYPE_F32, model_channels); // input_block_0_b // input_blocks int ds = 1; @@ -250,8 +250,8 @@ struct ControlNet : public GGMLModule { } for (int i = 0; i < num_zero_convs; i++) { - mem_size += zero_convs[i].channels * zero_convs[i].channels * ggml_type_sizef(GGML_TYPE_F16); - mem_size += zero_convs[i].channels * ggml_type_sizef(GGML_TYPE_F32); + mem_size += ggml_row_size(GGML_TYPE_F16, zero_convs[i].channels * zero_convs[i].channels); + mem_size += ggml_row_size(GGML_TYPE_F32, zero_convs[i].channels); } // middle_block @@ -259,10 +259,10 @@ struct ControlNet : public GGMLModule { mem_size += middle_block_1.calculate_mem_size(wtype); mem_size += middle_block_2.calculate_mem_size(wtype); - mem_size += middle_out_channel * middle_out_channel * ggml_type_sizef(GGML_TYPE_F16); // middle_block_out_w - mem_size += middle_out_channel * ggml_type_sizef(GGML_TYPE_F32); // middle_block_out_b + mem_size += ggml_row_size(GGML_TYPE_F16, middle_out_channel * middle_out_channel); // middle_block_out_w + mem_size += ggml_row_size(GGML_TYPE_F32, middle_out_channel); // middle_block_out_b - return static_cast(mem_size); + return mem_size; } size_t get_num_tensors() { diff --git a/unet.hpp b/unet.hpp index c4ece637d..f3b7128a2 100644 --- a/unet.hpp +++ b/unet.hpp @@ -180,21 +180,21 @@ struct UNetModel : public GGMLModule { } size_t calculate_mem_size() { - double mem_size = 0; - mem_size += time_embed_dim * model_channels * ggml_type_sizef(wtype); // time_embed_0_w - mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // time_embed_0_b - mem_size += time_embed_dim * time_embed_dim * ggml_type_sizef(wtype); // time_embed_2_w - mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // time_embed_2_b + size_t mem_size = 0; + mem_size += ggml_row_size(wtype, time_embed_dim * model_channels); // time_embed_0_w + mem_size += ggml_row_size(GGML_TYPE_F32, time_embed_dim); // time_embed_0_b + mem_size += ggml_row_size(wtype, time_embed_dim * time_embed_dim); // time_embed_2_w + mem_size += ggml_row_size(GGML_TYPE_F32, time_embed_dim); // time_embed_2_b if (version == VERSION_XL) { - mem_size += time_embed_dim * adm_in_channels * ggml_type_sizef(wtype); // label_embed_0_w - mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // label_embed_0_b - mem_size += time_embed_dim * time_embed_dim * ggml_type_sizef(wtype); // label_embed_2_w - mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // label_embed_2_b + mem_size += ggml_row_size(wtype, time_embed_dim * adm_in_channels); // label_embed_0_w + mem_size += ggml_row_size(GGML_TYPE_F32, time_embed_dim); // label_embed_0_b + mem_size += ggml_row_size(wtype, time_embed_dim * time_embed_dim); // label_embed_2_w + mem_size += ggml_row_size(GGML_TYPE_F32, time_embed_dim); // label_embed_2_b } - mem_size += model_channels * in_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // input_block_0_w - mem_size += model_channels * ggml_type_sizef(GGML_TYPE_F32); // input_block_0_b + mem_size += ggml_row_size(GGML_TYPE_F16, model_channels * in_channels * 3 * 3); // input_block_0_w + mem_size += ggml_row_size(GGML_TYPE_F32, model_channels); // input_block_0_b // input_blocks int ds = 1; @@ -235,11 +235,11 @@ struct UNetModel : public GGMLModule { } // out - mem_size += 2 * 
model_channels * ggml_type_sizef(GGML_TYPE_F32); // out_0_w/b - mem_size += out_channels * model_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // out_2_w - mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32); // out_2_b + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, model_channels); // out_0_w/b + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * model_channels * 3 * 3); // out_2_w + mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // out_2_b - return static_cast(mem_size); + return mem_size; } size_t get_num_tensors() { diff --git a/vae.hpp b/vae.hpp index 2abf93557..00b122729 100644 --- a/vae.hpp +++ b/vae.hpp @@ -32,14 +32,14 @@ struct ResnetBlock { size_t calculate_mem_size(ggml_type wtype) { double mem_size = 0; - mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // norm1_w/b - mem_size += out_channels * in_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // conv1_w - mem_size += 4 * out_channels * ggml_type_sizef(GGML_TYPE_F32); // conv1_b/norm2_w/norm2_b/conv2_b - mem_size += out_channels * out_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // conv2_w + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, in_channels); // norm1_w/b + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * in_channels * 3 * 3); // conv1_w + mem_size += 4 * ggml_row_size(GGML_TYPE_F32, out_channels); // conv1_b/norm2_w/norm2_b/conv2_b + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * out_channels * 3 * 3); // conv2_w if (out_channels != in_channels) { - mem_size += out_channels * in_channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16); // nin_shortcut_w - mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32); // nin_shortcut_b + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * in_channels * 1 * 1); // nin_shortcut_w + mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // nin_shortcut_b } return static_cast(mem_size); } @@ -120,8 +120,8 @@ struct AttnBlock { size_t calculate_mem_size(ggml_type wtype) { double mem_size = 0; - mem_size += 6 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // norm_w/norm_b/q_b/k_v/v_b/proj_out_b - mem_size += 4 * in_channels * in_channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16); // q_w/k_w/v_w/proj_out_w // object overhead + mem_size += 6 * ggml_row_size(GGML_TYPE_F32, in_channels); // norm_w/norm_b/q_b/k_v/v_b/proj_out_b + mem_size += 4 * ggml_row_size(GGML_TYPE_F16, in_channels * in_channels * 1 * 1); // q_w/k_w/v_w/proj_out_w // object overhead return static_cast(mem_size); } @@ -269,17 +269,17 @@ struct Encoder { } size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; + size_t mem_size = 0; int len_mults = sizeof(ch_mult) / sizeof(int); int block_in = ch * ch_mult[len_mults - 1]; - mem_size += ch * in_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // conv_in_w - mem_size += ch * ggml_type_sizef(GGML_TYPE_F32); // conv_in_b + mem_size += ggml_row_size(GGML_TYPE_F16, ch * in_channels * 3 * 3); // conv_in_w + mem_size += ggml_row_size(GGML_TYPE_F32, ch); // conv_in_b - mem_size += 2 * block_in * ggml_type_sizef(GGML_TYPE_F32); // norm_out_w/b + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, block_in); // norm_out_w/b - mem_size += z_channels * 2 * block_in * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // conv_out_w - mem_size += z_channels * 2 * ggml_type_sizef(GGML_TYPE_F32); // conv_out_b + mem_size += ggml_row_size(GGML_TYPE_F16, z_channels * 2 * block_in * 3 * 3); // conv_out_w + mem_size += ggml_row_size(GGML_TYPE_F32, z_channels * 2); // conv_out_b mem_size += mid.block_1.calculate_mem_size(wtype); 
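        // ggml_row_size(type, n) returns the exact byte count for n elements
        // (type_size * n / block_size; e.g. 64 Q4_0 elements = 2 blocks of
        // 18 bytes = 36 bytes). The old ggml_type_sizef() exposed only a
        // fractional per-element size, which forced the double accumulator
        // and the final cast; with row sizes the sum can stay a size_t.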
mem_size += mid.attn_1.calculate_mem_size(wtype); @@ -294,7 +294,7 @@ struct Encoder { } } - return static_cast(mem_size); + return mem_size; } void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) { @@ -436,13 +436,13 @@ struct Decoder { int len_mults = sizeof(ch_mult) / sizeof(int); int block_in = ch * ch_mult[len_mults - 1]; - mem_size += block_in * z_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // conv_in_w - mem_size += block_in * ggml_type_sizef(GGML_TYPE_F32); // conv_in_b + mem_size += ggml_row_size(GGML_TYPE_F16, block_in * z_channels * 3 * 3); // conv_in_w + mem_size += ggml_row_size(GGML_TYPE_F32, block_in); // conv_in_b - mem_size += 2 * (ch * ch_mult[0]) * ggml_type_sizef(GGML_TYPE_F32); // norm_out_w/b + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, (ch * ch_mult[0])); // norm_out_w/b - mem_size += (ch * ch_mult[0]) * out_ch * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // conv_out_w - mem_size += out_ch * ggml_type_sizef(GGML_TYPE_F32); // conv_out_b + mem_size += ggml_row_size(GGML_TYPE_F16, (ch * ch_mult[0]) * out_ch * 3 * 3); // conv_out_w + mem_size += ggml_row_size(GGML_TYPE_F32, out_ch); // conv_out_b mem_size += mid.block_1.calculate_mem_size(wtype); mem_size += mid.attn_1.calculate_mem_size(wtype); @@ -606,19 +606,19 @@ struct AutoEncoderKL : public GGMLModule { } size_t calculate_mem_size() { - double mem_size = 0; + size_t mem_size = 0; if (!decode_only) { - mem_size += 2 * embed_dim * 2 * dd_config.z_channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16); // quant_conv_w - mem_size += 2 * embed_dim * ggml_type_sizef(GGML_TYPE_F32); // quant_conv_b + mem_size += ggml_row_size(GGML_TYPE_F16, 2 * embed_dim * 2 * dd_config.z_channels * 1 * 1); // quant_conv_w + mem_size += ggml_row_size(GGML_TYPE_F32, 2 * embed_dim); // quant_conv_b mem_size += encoder.calculate_mem_size(wtype); } - mem_size += dd_config.z_channels * embed_dim * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16); // post_quant_conv_w - mem_size += dd_config.z_channels * ggml_type_sizef(GGML_TYPE_F32); // post_quant_conv_b + mem_size += ggml_row_size(GGML_TYPE_F16, dd_config.z_channels * embed_dim * 1 * 1); // post_quant_conv_w + mem_size += ggml_row_size(GGML_TYPE_F32, dd_config.z_channels); // post_quant_conv_b mem_size += decoder.calculate_mem_size(wtype); - return static_cast(mem_size); + return mem_size; } size_t get_num_tensors() { From bcf10aa29f575841839e8188c00cc4d3db4651be Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 5 Jan 2024 11:57:07 -0500 Subject: [PATCH 22/29] process hint once time + remove unused code --- clip.hpp | 8 ++--- control.hpp | 69 ++++++++++++++++++++++++++++++++++++++------ esrgan.hpp | 2 +- ggml_extend.hpp | 17 ++++------- stable-diffusion.cpp | 10 ++++--- tae.hpp | 2 +- unet.hpp | 10 +++---- vae.hpp | 2 +- 8 files changed, 84 insertions(+), 36 deletions(-) diff --git a/clip.hpp b/clip.hpp index b98f6fd7a..74718e445 100644 --- a/clip.hpp +++ b/clip.hpp @@ -992,7 +992,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { ggml_allocr_alloc(allocr, input_ids); if (!ggml_allocr_is_measure(allocr)) { - ggml_backend_tensor_set_and_sync(backend, input_ids, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids)); + ggml_backend_tensor_set(input_ids, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids)); } struct ggml_tensor* input_ids2 = NULL; @@ -1014,7 +1014,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { // printf("\n"); if (!ggml_allocr_is_measure(allocr)) { - ggml_backend_tensor_set_and_sync(backend, 
input_ids2, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids2)); + ggml_backend_tensor_set(input_ids2, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids2)); } } @@ -1027,12 +1027,12 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { // really bad, there is memory inflexibility (this is for host<->device memory conflicts) void* freeze_data = malloc(ggml_nbytes(text_model.token_embed_weight)); ggml_backend_tensor_get_and_sync(backend, text_model.token_embed_weight, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight)); - ggml_backend_tensor_set_and_sync(backend, embeddings, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight)); + ggml_backend_tensor_set(embeddings, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight)); free(freeze_data); // concatenate custom embeddings void* custom_data = malloc(ggml_nbytes(text_model.token_embed_custom)); ggml_backend_tensor_get_and_sync(backend, text_model.token_embed_custom, custom_data, 0, ggml_nbytes(text_model.token_embed_custom)); - ggml_backend_tensor_set_and_sync(backend, embeddings, custom_data, ggml_nbytes(text_model.token_embed_weight), text_model.num_custom_embeddings * text_model.hidden_size * ggml_type_size(wtype)); + ggml_backend_tensor_set(embeddings, custom_data, ggml_nbytes(text_model.token_embed_weight), text_model.num_custom_embeddings * text_model.hidden_size * ggml_type_size(wtype)); free(custom_data); } } diff --git a/control.hpp b/control.hpp index eddfd1470..1c5ed4e72 100644 --- a/control.hpp +++ b/control.hpp @@ -153,6 +153,7 @@ struct ControlNet : public GGMLModule { std::vector controls; // (12 input block outputs, 1 middle block output) SD 1.5 ControlNet() { + name = "controlnet"; // input_blocks std::vector input_block_chans; input_block_chans.push_back(model_channels); @@ -461,6 +462,51 @@ struct ControlNet : public GGMLModule { tensors[prefix + "middle_block_out.0.bias"] = middle_block_out_b; } + struct ggml_cgraph* build_graph_hint(struct ggml_tensor* hint) { + // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/buf.data(), + /*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + + struct ggml_context* ctx0 = ggml_init(params); + struct ggml_cgraph* gf = ggml_new_graph(ctx0); + // temporal tensors for transfer tensors from cpu to gpu if needed + struct ggml_tensor* hint_t = NULL; + // it's performing a compute, check if backend isn't cpu + if (!ggml_backend_is_cpu(backend)) { + // pass input tensors to gpu memory + hint_t = ggml_dup_tensor(ctx0, hint); + ggml_allocr_alloc(compute_allocr, hint_t); + // pass data to device backend + if (!ggml_allocr_is_measure(compute_allocr)) { + ggml_backend_tensor_set(hint_t, hint->data, 0, ggml_nbytes(hint)); + } + } else { + // if it's cpu backend just pass the same tensors + hint_t = hint; + } + struct ggml_tensor* out = input_hint_block.forward(ctx0, hint_t); + ggml_build_forward_expand(gf, out); + ggml_free(ctx0); + return gf; + } + + void process_hint(struct ggml_tensor* output, int n_threads, struct ggml_tensor* hint) { + // compute buffer size + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph_hint(hint); + }; + GGMLModule::alloc_compute_buffer(get_graph); + 
// perform computation + GGMLModule::compute(get_graph, n_threads, output); + GGMLModule::free_compute_buffer(); + } + void forward(struct ggml_cgraph* gf, struct ggml_context* ctx0, struct ggml_tensor* x, @@ -481,14 +527,12 @@ struct ControlNet : public GGMLModule { emb = ggml_silu_inplace(ctx0, emb); emb = ggml_nn_linear(ctx0, emb, time_embed_2_w, time_embed_2_b); // [N, time_embed_dim] - auto guided_hint = input_hint_block.forward(ctx0, hint); - // input_blocks int zero_conv_offset = 0; // input block 0 struct ggml_tensor* h = ggml_nn_conv_2d(ctx0, x, input_block_0_w, input_block_0_b, 1, 1, 1, 1); // [N, model_channels, h, w] - h = ggml_add(ctx0, h, guided_hint); + h = ggml_add(ctx0, h, hint); auto h_c = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b); ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset])); @@ -573,14 +617,14 @@ struct ControlNet : public GGMLModule { } // pass data to device backend if (!ggml_allocr_is_measure(compute_allocr)) { - ggml_backend_tensor_set_and_sync(backend, x_t, x->data, 0, ggml_nbytes(x)); - ggml_backend_tensor_set_and_sync(backend, context_t, context->data, 0, ggml_nbytes(context)); - ggml_backend_tensor_set_and_sync(backend, hint_t, hint->data, 0, ggml_nbytes(hint)); + ggml_backend_tensor_set(x_t, x->data, 0, ggml_nbytes(x)); + ggml_backend_tensor_set(context_t, context->data, 0, ggml_nbytes(context)); + ggml_backend_tensor_set(hint_t, hint->data, 0, ggml_nbytes(hint)); if (timesteps_t != NULL) { - ggml_backend_tensor_set_and_sync(backend, timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps)); + ggml_backend_tensor_set(timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps)); } if (t_emb_t != NULL) { - ggml_backend_tensor_set_and_sync(backend, t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); + ggml_backend_tensor_set(t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); } } } else { @@ -605,7 +649,7 @@ struct ControlNet : public GGMLModule { struct ggml_tensor* t_emb = NULL) { { struct ggml_init_params params; - params.mem_size = static_cast(13 * ggml_tensor_overhead()) + 256; + params.mem_size = static_cast(14 * ggml_tensor_overhead()) + 256; params.mem_buffer = NULL; params.no_alloc = true; control_ctx = ggml_init(params); @@ -639,6 +683,13 @@ struct ControlNet : public GGMLModule { }; GGMLModule::compute(get_graph, n_threads, NULL); } + + void end() { + GGMLModule::free_compute_buffer(); + ggml_free(control_ctx); + ggml_backend_buffer_free(control_buffer); + control_buffer = NULL; + } }; #endif // __CONTROL_HPP__ \ No newline at end of file diff --git a/esrgan.hpp b/esrgan.hpp index 2b4ff5d64..90194c0d9 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -386,7 +386,7 @@ struct ESRGAN : public GGMLModule { // pass data to device backend if (!ggml_allocr_is_measure(compute_allocr)) { - ggml_backend_tensor_set_and_sync(backend, x_, x->data, 0, ggml_nbytes(x)); + ggml_backend_tensor_set(x_, x->data, 0, ggml_nbytes(x)); } } else { x_ = x; diff --git a/ggml_extend.hpp b/ggml_extend.hpp index c14721269..60ab430c1 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -460,19 +460,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct return x; } -__STATIC_INLINE__ void ggml_backend_tensor_set_and_sync(ggml_backend_t backend, struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { -#ifdef SD_USE_CUBLAS - ggml_backend_tensor_set_async(backend, tensor, data, offset, size); - ggml_backend_synchronize(backend); -#else - ggml_backend_tensor_set(tensor, 
data, offset, size); -#endif -} - __STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, size_t size) { #ifdef SD_USE_CUBLAS - ggml_backend_tensor_get_async(backend, tensor, data, offset, size); - ggml_backend_synchronize(backend); + if(!ggml_backend_is_cpu(backend)) { + ggml_backend_tensor_get_async(backend, tensor, data, offset, size); + ggml_backend_synchronize(backend); + } else { + ggml_backend_tensor_get(tensor, data, offset, size); + } #else ggml_backend_tensor_get(tensor, data, offset, size); #endif diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 5263d7203..a52ef981e 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -544,9 +544,11 @@ class StableDiffusionGGML { struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x_t); struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1); // [N, ] struct ggml_tensor* t_emb = new_timestep_embedding(work_ctx, NULL, timesteps, diffusion_model.model_channels); // [N, model_channels] - + struct ggml_tensor* guided_hint = NULL; if(control_hint != NULL) { - control_net.alloc_compute_buffer(noised_input, control_hint, c, t_emb); + guided_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, noised_input->ne[0], noised_input->ne[1], diffusion_model.model_channels, 1); + control_net.process_hint(guided_hint, n_threads, control_hint); + control_net.alloc_compute_buffer(noised_input, guided_hint, c, t_emb); } diffusion_model.alloc_compute_buffer(noised_input, c, control_net.controls, t_emb, c_vector); @@ -600,7 +602,7 @@ class StableDiffusionGGML { // cond if(control_hint != NULL) { - control_net.compute(n_threads, noised_input, control_hint, c, t_emb); + control_net.compute(n_threads, noised_input, guided_hint, c, t_emb); } diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, control_net.controls, control_strength, t_emb, c_vector); @@ -608,7 +610,7 @@ class StableDiffusionGGML { if (has_unconditioned) { // uncond if(control_hint != NULL) { - control_net.compute(n_threads, noised_input, control_hint, uc, t_emb); + control_net.compute(n_threads, noised_input, guided_hint, uc, t_emb); } diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, control_net.controls, control_strength, t_emb, uc_vector); diff --git a/tae.hpp b/tae.hpp index 981377523..405ac9c43 100644 --- a/tae.hpp +++ b/tae.hpp @@ -549,7 +549,7 @@ struct TinyAutoEncoder : public GGMLModule { // pass data to device backend if (!ggml_allocr_is_measure(compute_allocr)) { - ggml_backend_tensor_set_and_sync(backend, z_, z->data, 0, ggml_nbytes(z)); + ggml_backend_tensor_set(z_, z->data, 0, ggml_nbytes(z)); } } else { z_ = z; diff --git a/unet.hpp b/unet.hpp index f3b7128a2..2c9e7c929 100644 --- a/unet.hpp +++ b/unet.hpp @@ -589,16 +589,16 @@ struct UNetModel : public GGMLModule { } // pass data to device backend if (!ggml_allocr_is_measure(compute_allocr)) { - ggml_backend_tensor_set_and_sync(backend, x_t, x->data, 0, ggml_nbytes(x)); - ggml_backend_tensor_set_and_sync(backend, context_t, context->data, 0, ggml_nbytes(context)); + ggml_backend_tensor_set(x_t, x->data, 0, ggml_nbytes(x)); + ggml_backend_tensor_set(context_t, context->data, 0, ggml_nbytes(context)); if (timesteps_t != NULL) { - ggml_backend_tensor_set_and_sync(backend, timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps)); + ggml_backend_tensor_set(timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps)); } if (t_emb_t != NULL) { - 
ggml_backend_tensor_set_and_sync(backend, t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); + ggml_backend_tensor_set(t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); } if (y != NULL) { - ggml_backend_tensor_set_and_sync(backend, y_t, y->data, 0, ggml_nbytes(y)); + ggml_backend_tensor_set(y_t, y->data, 0, ggml_nbytes(y)); } } } else { diff --git a/vae.hpp b/vae.hpp index 00b122729..38af54085 100644 --- a/vae.hpp +++ b/vae.hpp @@ -707,7 +707,7 @@ struct AutoEncoderKL : public GGMLModule { // pass data to device backend if (!ggml_allocr_is_measure(compute_allocr)) { - ggml_backend_tensor_set_and_sync(backend, z_, z->data, 0, ggml_nbytes(z)); + ggml_backend_tensor_set(z_, z->data, 0, ggml_nbytes(z)); } } else { z_ = z; From 31c4367b0ce0bb48a8da0c618acc7b6e75b49314 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 5 Jan 2024 15:24:39 -0500 Subject: [PATCH 23/29] fix enable vae-tiling --- model.cpp | 6 +----- stable-diffusion.cpp | 4 +++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/model.cpp b/model.cpp index 19f9ece21..77153e6fd 100644 --- a/model.cpp +++ b/model.cpp @@ -1312,11 +1312,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend size_t nbytes_to_read = tensor_storage.nbytes_to_read(); - if (dst_tensor->buffer == NULL || ggml_backend_is_cpu(backend) -#ifdef SD_USE_METAL - || ggml_backend_is_metal(backend) -#endif - ) { + if (dst_tensor->buffer == NULL || ggml_backend_buffer_is_host(dst_tensor->buffer)) { // for the CPU and Metal backend, we can copy directly into the tensor if (tensor_storage.type == dst_tensor->type) { GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes()); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index a52ef981e..d78a90017 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -112,7 +112,7 @@ class StableDiffusionGGML { const std::string control_net_path, const std::string embeddings_path, const std::string& taesd_path, - bool vae_tiling, + bool vae_tiling_, ggml_type wtype, schedule_t schedule, bool control_net_cpu) { @@ -140,6 +140,8 @@ class StableDiffusionGGML { LOG_INFO("loading model from '%s'", model_path.c_str()); ModelLoader model_loader; + vae_tiling = vae_tiling_; + if (!model_loader.init_from_file(model_path)) { LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); return false; From 58be850441f47953be9594b3e590b609e00d47a9 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 5 Jan 2024 18:35:04 -0500 Subject: [PATCH 24/29] release controlnet memory --- stable-diffusion.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index d78a90017..58e1f3f25 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1024,6 +1024,7 @@ class StableDiffusionGGML { LOG_ERROR("Attempting to sample with nonexisting sample method %i", method); abort(); } + control_net.end(); diffusion_model.free_compute_buffer(); return x; } From 2b516a4a3abe1e7fe2ca6231d108435efa485cf2 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 5 Jan 2024 19:19:19 -0500 Subject: [PATCH 25/29] remove unused cmake option --- CMakeLists.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 952cfb771..74334d2d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,6 @@ option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE}) option(SD_CUBLAS "sd: cuda backend" OFF) option(SD_METAL "sd: metal backend" OFF) option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF) 
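# Note: the flash-attention path is only taken when neither CUBLAS nor Metal
# is enabled; see the SD_USE_FLASH_ATTENTION guards around ggml_flash_attn()
# in the attention code above.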
-option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF) option(BUILD_SHARED_LIBS "sd: build shared libs" OFF) #option(SD_BUILD_SERVER "sd: build server example" ON) @@ -35,9 +34,6 @@ if(SD_CUBLAS) message("Use CUBLAS as backend stable-diffusion") set(GGML_CUBLAS ON) add_definitions(-DSD_USE_CUBLAS) - if(SD_FAST_SOFTMAX) - set(GGML_CUDA_FAST_SOFTMAX ON) - endif() endif() if(SD_METAL) From cbd6a7e69da962949154bcde039102cf54c9119e Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 14 Jan 2024 12:53:55 +0800 Subject: [PATCH 26/29] use on_new_token_cb_t --- clip.hpp | 529 ++++++++++++++++++++++++++++--------------------------- model.h | 1 - 2 files changed, 271 insertions(+), 259 deletions(-) diff --git a/clip.hpp b/clip.hpp index 74718e445..616b8f119 100644 --- a/clip.hpp +++ b/clip.hpp @@ -67,6 +67,252 @@ std::vector> bytes_to_unicode() { return byte_unicode_pairs; } +// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py + +typedef std::function&)> on_new_token_cb_t; + +class CLIPTokenizer { +private: + SDVersion version = VERSION_1_x; + std::map byte_encoder; + std::map encoder; + std::map, int> bpe_ranks; + std::regex pat; + + static std::string strip(const std::string& str) { + std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); + std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); + + if (start == std::string::npos) { + // String contains only whitespace characters + return ""; + } + + return str.substr(start, end - start + 1); + } + + static std::string whitespace_clean(std::string text) { + text = std::regex_replace(text, std::regex(R"(\s+)"), " "); + text = strip(text); + return text; + } + + static std::set> get_pairs(const std::vector& subwords) { + std::set> pairs; + if (subwords.size() == 0) { + return pairs; + } + std::u32string prev_subword = subwords[0]; + for (int i = 1; i < subwords.size(); i++) { + std::u32string subword = subwords[i]; + std::pair pair(prev_subword, subword); + pairs.insert(pair); + prev_subword = subword; + } + return pairs; + } + +public: + CLIPTokenizer(SDVersion version = VERSION_1_x) + : version(version) {} + + void load_from_merges(const std::string& merges_utf8_str) { + auto byte_unicode_pairs = bytes_to_unicode(); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + // for (auto & pair: byte_unicode_pairs) { + // std::cout << pair.first << ": " << pair.second << std::endl; + // } + std::vector merges; + size_t start = 0; + size_t pos; + std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); + while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { + merges.push_back(merges_utf32_str.substr(start, pos - start)); + start = pos + 1; + } + // LOG_DEBUG("merges size %llu", merges.size()); + GGML_ASSERT(merges.size() == 48895); + merges = std::vector(merges.begin() + 1, merges.end()); + std::vector> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); + } + std::vector vocab; + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second); + } + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second + utf8_to_utf32("")); + } + for (const auto& merge : merge_pairs) { + vocab.push_back(merge.first + merge.second); + } + 
vocab.push_back(utf8_to_utf32("<|startoftext|>")); + vocab.push_back(utf8_to_utf32("<|endoftext|>")); + LOG_DEBUG("vocab size: %llu", vocab.size()); + int i = 0; + for (const auto& token : vocab) { + encoder[token] = i++; + } + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + }; + + std::u32string bpe(const std::u32string& token) { + std::vector word; + + for (int i = 0; i < token.size() - 1; i++) { + word.emplace_back(1, token[i]); + } + word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("")); + + std::set> pairs = get_pairs(word); + + if (pairs.empty()) { + return token + utf8_to_utf32(""); + } + + while (true) { + auto min_pair_iter = std::min_element(pairs.begin(), + pairs.end(), + [&](const std::pair& a, + const std::pair& b) { + if (bpe_ranks.find(a) == bpe_ranks.end()) { + return false; + } else if (bpe_ranks.find(b) == bpe_ranks.end()) { + return true; + } + return bpe_ranks.at(a) < bpe_ranks.at(b); + }); + + const std::pair& bigram = *min_pair_iter; + + if (bpe_ranks.find(bigram) == bpe_ranks.end()) { + break; + } + + std::u32string first = bigram.first; + std::u32string second = bigram.second; + std::vector new_word; + int32_t i = 0; + + while (i < word.size()) { + auto it = std::find(word.begin() + i, word.end(), first); + if (it == word.end()) { + new_word.insert(new_word.end(), word.begin() + i, word.end()); + break; + } + new_word.insert(new_word.end(), word.begin() + i, it); + i = static_cast(std::distance(word.begin(), it)); + + if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { + new_word.push_back(first + second); + i += 2; + } else { + new_word.push_back(word[i]); + i += 1; + } + } + + word = new_word; + + if (word.size() == 1) { + break; + } + pairs = get_pairs(word); + } + + std::u32string result; + for (int i = 0; i < word.size(); i++) { + result += word[i]; + if (i != word.size() - 1) { + result += utf8_to_utf32(" "); + } + } + + return result; + } + + std::vector tokenize(std::string text, + on_new_token_cb_t on_new_token_cb, + size_t max_length = 0, + bool padding = false) { + std::vector tokens = encode(text, on_new_token_cb); + tokens.insert(tokens.begin(), BOS_TOKEN_ID); + if (max_length > 0) { + if (tokens.size() > max_length - 1) { + tokens.resize(max_length - 1); + tokens.push_back(EOS_TOKEN_ID); + } else { + tokens.push_back(EOS_TOKEN_ID); + if (padding) { + int pad_token_id = PAD_TOKEN_ID; + if (version == VERSION_2_x) { + pad_token_id = 0; + } + tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id); + } + } + } + return tokens; + } + + std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb) { + std::string original_text = text; + std::vector bpe_tokens; + text = whitespace_clean(text); + std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); }); + + std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)", + std::regex::icase); + + std::smatch matches; + std::string str = text; + std::vector token_strs; + while (std::regex_search(str, matches, pat)) { + bool skip = on_new_token_cb(str, bpe_tokens); + if (skip) { + continue; + } + for (auto& token : matches) { + std::string token_str = token.str(); + std::u32string utf32_token; + for (int i = 0; i < token_str.length(); i++) { + char b = token_str[i]; + utf32_token += byte_encoder[b]; + } + auto bpe_strs = bpe(utf32_token); + size_t start = 0; + size_t pos; + while 
((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { + auto bpe_str = bpe_strs.substr(start, pos - start); + bpe_tokens.push_back(encoder[bpe_str]); + token_strs.push_back(utf32_to_utf8(bpe_str)); + + start = pos + 1; + } + auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); + bpe_tokens.push_back(encoder[bpe_str]); + token_strs.push_back(utf32_to_utf8(bpe_str)); + } + str = matches.suffix(); + } + std::stringstream ss; + ss << "["; + for (auto token : token_strs) { + ss << "\"" << token << "\", "; + } + ss << "]"; + LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); + return bpe_tokens; + } +}; + // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345 // // Parses a string with attention tokens and returns a list of pairs: text and its associated weight. @@ -558,263 +804,6 @@ struct CLIPTextModel { } }; - -// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py -class CLIPTokenizer { -private: - SDVersion version = VERSION_1_x; - std::map byte_encoder; - std::map encoder; - std::map, int> bpe_ranks; - std::regex pat; - - static std::string strip(const std::string& str) { - std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); - std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); - - if (start == std::string::npos) { - // String contains only whitespace characters - return ""; - } - - return str.substr(start, end - start + 1); - } - - static std::string whitespace_clean(std::string text) { - text = std::regex_replace(text, std::regex(R"(\s+)"), " "); - text = strip(text); - return text; - } - - static std::set> get_pairs(const std::vector& subwords) { - std::set> pairs; - if (subwords.size() == 0) { - return pairs; - } - std::u32string prev_subword = subwords[0]; - for (int i = 1; i < subwords.size(); i++) { - std::u32string subword = subwords[i]; - std::pair pair(prev_subword, subword); - pairs.insert(pair); - prev_subword = subword; - } - return pairs; - } - -public: - CLIPTokenizer(SDVersion version = VERSION_1_x) - : version(version) {} - - void load_from_merges(const std::string& merges_utf8_str) { - auto byte_unicode_pairs = bytes_to_unicode(); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); - // for (auto & pair: byte_unicode_pairs) { - // std::cout << pair.first << ": " << pair.second << std::endl; - // } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - // LOG_DEBUG("merges size %llu", merges.size()); - GGML_ASSERT(merges.size() == 48895); - merges = std::vector(merges.begin() + 1, merges.end()); - std::vector> merge_pairs; - for (const auto& merge : merges) { - size_t space_pos = merge.find(' '); - merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - } - std::vector vocab; - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second); - } - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second + utf8_to_utf32("")); - } - for (const auto& merge : merge_pairs) { - vocab.push_back(merge.first + merge.second); - } - vocab.push_back(utf8_to_utf32("<|startoftext|>")); - 
vocab.push_back(utf8_to_utf32("<|endoftext|>")); - LOG_DEBUG("vocab size: %llu", vocab.size()); - int i = 0; - for (const auto& token : vocab) { - encoder[token] = i++; - } - - int rank = 0; - for (const auto& merge : merge_pairs) { - bpe_ranks[merge] = rank++; - } - }; - - std::u32string bpe(const std::u32string& token) { - std::vector word; - - for (int i = 0; i < token.size() - 1; i++) { - word.emplace_back(1, token[i]); - } - word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("")); - - std::set> pairs = get_pairs(word); - - if (pairs.empty()) { - return token + utf8_to_utf32(""); - } - - while (true) { - auto min_pair_iter = std::min_element(pairs.begin(), - pairs.end(), - [&](const std::pair& a, - const std::pair& b) { - if (bpe_ranks.find(a) == bpe_ranks.end()) { - return false; - } else if (bpe_ranks.find(b) == bpe_ranks.end()) { - return true; - } - return bpe_ranks.at(a) < bpe_ranks.at(b); - }); - - const std::pair& bigram = *min_pair_iter; - - if (bpe_ranks.find(bigram) == bpe_ranks.end()) { - break; - } - - std::u32string first = bigram.first; - std::u32string second = bigram.second; - std::vector new_word; - int32_t i = 0; - - while (i < word.size()) { - auto it = std::find(word.begin() + i, word.end(), first); - if (it == word.end()) { - new_word.insert(new_word.end(), word.begin() + i, word.end()); - break; - } - new_word.insert(new_word.end(), word.begin() + i, it); - i = static_cast(std::distance(word.begin(), it)); - - if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { - new_word.push_back(first + second); - i += 2; - } else { - new_word.push_back(word[i]); - i += 1; - } - } - - word = new_word; - - if (word.size() == 1) { - break; - } - pairs = get_pairs(word); - } - - std::u32string result; - for (int i = 0; i < word.size(); i++) { - result += word[i]; - if (i != word.size() - 1) { - result += utf8_to_utf32(" "); - } - } - - return result; - } - - std::vector tokenize(std::string text, CLIPTextModel& text_model, size_t max_length = 0, bool padding = false) { - std::vector tokens = encode(text, text_model); - tokens.insert(tokens.begin(), BOS_TOKEN_ID); - if (max_length > 0) { - if (tokens.size() > max_length - 1) { - tokens.resize(max_length - 1); - tokens.push_back(EOS_TOKEN_ID); - } else { - tokens.push_back(EOS_TOKEN_ID); - if (padding) { - int pad_token_id = PAD_TOKEN_ID; - if (version == VERSION_2_x) { - pad_token_id = 0; - } - tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id); - } - } - } - return tokens; - } - - std::vector encode(std::string text, CLIPTextModel& text_model) { - std::string original_text = text; - std::vector bpe_tokens; - text = whitespace_clean(text); - std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); }); - - std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)", - std::regex::icase); - - std::smatch matches; - std::string str = text; - std::vector token_strs; - while (std::regex_search(str, matches, pat)) { - size_t word_end = str.find(","); - std::string embd_name = word_end == std::string::npos ? 
str : str.substr(0, word_end); - embd_name = trim(embd_name); - std::string embd_path = path_join(text_model.embd_dir, embd_name + ".pt"); - if(!file_exists(embd_path)) { - embd_path = path_join(text_model.embd_dir, embd_name + ".ckpt"); - } - if(!file_exists(embd_path)) { - embd_path = path_join(text_model.embd_dir, embd_name + ".safetensors"); - } - if(file_exists(embd_path)) { - if(text_model.load_embedding(embd_name, embd_path, bpe_tokens)) { - if(word_end != std::string::npos) { - str = str.substr(word_end); - } else { - str = ""; - } - continue; - } - } - for (auto& token : matches) { - std::string token_str = token.str(); - std::u32string utf32_token; - for (int i = 0; i < token_str.length(); i++) { - char b = token_str[i]; - utf32_token += byte_encoder[b]; - } - auto bpe_strs = bpe(utf32_token); - size_t start = 0; - size_t pos; - while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { - auto bpe_str = bpe_strs.substr(start, pos - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - - start = pos + 1; - } - auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - } - str = matches.suffix(); - } - std::stringstream ss; - ss << "["; - for (auto token : token_strs) { - ss << "\"" << token << "\", "; - } - ss << "]"; - LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); - return bpe_tokens; - } -}; - // ldm.modules.encoders.modules.FrozenCLIPEmbedder // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { @@ -924,12 +913,36 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); } + auto on_new_token_cb = [&] (std::string& str, std::vector &bpe_tokens) -> bool { + size_t word_end = str.find(","); + std::string embd_name = word_end == std::string::npos ? 
str : str.substr(0, word_end); + embd_name = trim(embd_name); + std::string embd_path = path_join(text_model.embd_dir, embd_name + ".pt"); + if(!file_exists(embd_path)) { + embd_path = path_join(text_model.embd_dir, embd_name + ".ckpt"); + } + if(!file_exists(embd_path)) { + embd_path = path_join(text_model.embd_dir, embd_name + ".safetensors"); + } + if(file_exists(embd_path)) { + if(text_model.load_embedding(embd_name, embd_path, bpe_tokens)) { + if(word_end != std::string::npos) { + str = str.substr(word_end); + } else { + str = ""; + } + return true; + } + } + return false; + }; + std::vector tokens; std::vector weights; for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = tokenizer.encode(curr_text, text_model); + std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); weights.insert(weights.end(), curr_tokens.size(), curr_weight); } diff --git a/model.h b/model.h index 4b692a303..13665a7ea 100644 --- a/model.h +++ b/model.h @@ -93,7 +93,6 @@ struct TensorStorage { }; typedef std::function on_new_tensor_cb_t; -typedef std::function on_new_token_cb_t; class ModelLoader { protected: From d5de88e1905c73206ee49733e551a68fdd8ff6a4 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Mon, 22 Jan 2024 09:36:46 -0500 Subject: [PATCH 27/29] fix taesd option --- .gitignore | 1 - stable-diffusion.cpp | 13 ++++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 31f892160..38fe570df 100644 --- a/.gitignore +++ b/.gitignore @@ -10,5 +10,4 @@ test/ *.gguf output*.png models* -!taesd-model.gguf *.log \ No newline at end of file diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 58e1f3f25..efaa50387 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -116,6 +116,7 @@ class StableDiffusionGGML { ggml_type wtype, schedule_t schedule, bool control_net_cpu) { + use_tiny_autoencoder = taesd_path.size() > 0; #ifdef SD_USE_CUBLAS LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); @@ -319,10 +320,6 @@ class StableDiffusionGGML { LOG_DEBUG("finished loaded file"); ggml_free(ctx); - if (use_tiny_autoencoder) { - return tae_first_stage.load_from_file(taesd_path, backend); - } - if(control_net_path.size() > 0) { ggml_backend_t cn_backend = NULL; if(control_net_cpu && !ggml_backend_is_cpu(backend)) { @@ -331,7 +328,13 @@ class StableDiffusionGGML { } else { cn_backend = backend; } - return control_net.load_from_file(control_net_path, cn_backend, GGML_TYPE_F16 /* just f16 controlnet models */); + if(!control_net.load_from_file(control_net_path, cn_backend, GGML_TYPE_F16 /* just f16 controlnet models */)) { + return false; + } + } + + if (use_tiny_autoencoder) { + return tae_first_stage.load_from_file(taesd_path, backend); } return true; } From 7d8f84a3956540a54db90c9d04b56230c18823a3 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 14 Jan 2024 16:03:16 +0800 Subject: [PATCH 28/29] keep naming stype consistent --- control.hpp | 2 +- stable-diffusion.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/control.hpp b/control.hpp index 1c5ed4e72..543998f43 100644 --- a/control.hpp +++ b/control.hpp @@ -684,7 +684,7 @@ struct ControlNet : public GGMLModule { GGMLModule::compute(get_graph, n_threads, NULL); } - void end() { + void free_compute_buffer() { GGMLModule::free_compute_buffer(); ggml_free(control_ctx); 
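        // control_ctx owns the metadata of the cached control tensors and
        // control_buffer holds their data, so both are released together;
        // control_buffer is reset to NULL just below so the next sampling
        // call can allocate a fresh one.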
ggml_backend_buffer_free(control_buffer); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index efaa50387..ee67dd434 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1027,7 +1027,7 @@ class StableDiffusionGGML { LOG_ERROR("Attempting to sample with nonexisting sample method %i", method); abort(); } - control_net.end(); + control_net.free_compute_buffer(); diffusion_model.free_compute_buffer(); return x; } From 29e612d222cbe1eb9ddc24f141d5e842dd2c824a Mon Sep 17 00:00:00 2001 From: leejet Date: Mon, 29 Jan 2024 22:30:22 +0800 Subject: [PATCH 29/29] ignore the embedding name case --- clip.hpp | 12 ++++++------ util.cpp | 33 +++++++++++++++++++++++++++++++++ util.h | 1 + 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/clip.hpp b/clip.hpp index 616b8f119..e04510997 100644 --- a/clip.hpp +++ b/clip.hpp @@ -917,14 +917,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { size_t word_end = str.find(","); std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end); embd_name = trim(embd_name); - std::string embd_path = path_join(text_model.embd_dir, embd_name + ".pt"); - if(!file_exists(embd_path)) { - embd_path = path_join(text_model.embd_dir, embd_name + ".ckpt"); + std::string embd_path = get_full_path(text_model.embd_dir, embd_name + ".pt"); + if(embd_path.size() == 0) { + embd_path = get_full_path(text_model.embd_dir, embd_name + ".ckpt"); } - if(!file_exists(embd_path)) { - embd_path = path_join(text_model.embd_dir, embd_name + ".safetensors"); + if(embd_path.size() == 0) { + embd_path = get_full_path(text_model.embd_dir, embd_name + ".safetensors"); } - if(file_exists(embd_path)) { + if(embd_path.size() > 0) { if(text_model.load_embedding(embd_name, embd_path, bpe_tokens)) { if(word_end != std::string::npos) { str = str.substr(word_end); diff --git a/util.cpp b/util.cpp index e8b9e4250..4445f6c5e 100644 --- a/util.cpp +++ b/util.cpp @@ -72,6 +72,20 @@ bool is_directory(const std::string& path) { return (attributes != INVALID_FILE_ATTRIBUTES && (attributes & FILE_ATTRIBUTE_DIRECTORY)); } +std::string get_full_path(const std::string& dir, const std::string& filename) { + std::string full_path = dir + "\\" + filename; + + WIN32_FIND_DATA find_file_data; + HANDLE hFind = FindFirstFile(full_path.c_str(), &find_file_data); + + if (hFind != INVALID_HANDLE_VALUE) { + FindClose(hFind); + return full_path; + } else { + return ""; + } +} + #else // Unix #include #include @@ -86,6 +100,25 @@ bool is_directory(const std::string& path) { return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode)); } +std::string get_full_path(const std::string& dir, const std::string& filename) { + DIR* dp = opendir(dir.c_str()); + + if (dp != nullptr) { + struct dirent* entry; + + while ((entry = readdir(dp)) != nullptr) { + if (strcasecmp(entry->d_name, filename.c_str()) == 0) { + closedir(dp); + return dir + "/" + entry->d_name; + } + } + + closedir(dp); + } + + return ""; +} + #endif // get_num_physical_cores is copy from diff --git a/util.h b/util.h index ca0830e12..c1b035f17 100644 --- a/util.h +++ b/util.h @@ -15,6 +15,7 @@ void replace_all_chars(std::string& str, char target, char replacement); bool file_exists(const std::string& filename); bool is_directory(const std::string& path); +std::string get_full_path(const std::string& dir, const std::string& filename); std::u32string utf8_to_utf32(const std::string& utf8_str); std::string utf32_to_utf8(const std::u32string& utf32_str);
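For illustration, a minimal sketch of the caller-side behavior of the new case-insensitive lookup. The embeddings directory and file name below are hypothetical, and the snippet assumes it is compiled together with the util.cpp above; only get_full_path() itself comes from the patch.

#include <iostream>
#include <string>

// declared in util.h above; returns "" when nothing in `dir` matches `filename`
std::string get_full_path(const std::string& dir, const std::string& filename);

int main() {
    // hypothetical setup: the prompt references "mystyle", but the file on
    // disk is saved as "MyStyle.safetensors"
    std::string path = get_full_path("models/embeddings", "mystyle.safetensors");
    if (path.empty()) {
        std::cout << "embedding not found" << std::endl;
    } else {
        // the returned path carries the real on-disk casing
        std::cout << "loading embedding from " << path << std::endl;
    }
    return 0;
}

The case-insensitivity comes from FindFirstFile() on Windows, where name matching on default volumes already ignores case, and from the explicit strcasecmp() in the POSIX directory scan, which is what lets the embedding name in a prompt differ in case from the file name on disk.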