#include <cstdint>
#include <cstdio>   // FILE/fopen; must precede <jpeglib.h>, whose API uses FILE
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <jpeglib.h>

/// Reads the first `size` bytes of `filename` (opened in binary mode) into `buffer`.
///
/// @param filename path of the raw frame file.
/// @param buffer   destination; must have room for at least `size` bytes.
/// @param size     number of bytes to read.
/// @return true only if the file could be opened and exactly `size` bytes were read.
///         A file longer than `size` still succeeds; the trailing bytes are ignored.
bool ReadYuvFromFile(const std::string& filename, uint8_t* buffer, int size) {
    std::ifstream input(filename, std::ios::binary);
    if (input) {
        input.read(reinterpret_cast<char*>(buffer), size);
        const std::streamsize bytesRead = input.gcount();
        return bytesRead == size;
    }
    // Open failed (missing file, permissions, ...).
    return false;
}

int Nv12ToJpgFile(const char *pFileName, uint8_t* pYUVBuffer, const int nWidth, const int nHeight)
{
    struct jpeg_compress_struct cinfo;
    struct jpeg_error_mgr jerr;
    JSAMPROW row_pointer[1];  
    FILE * pJpegFile = NULL;
    unsigned char *yuvbuf = NULL;
    unsigned char *ybase = NULL, *ubase = NULL;
    int i=0, j=0;
    int idx=0;

    cinfo.err = jpeg_std_error(&jerr);
    jpeg_create_compress(&cinfo);

    if ((pJpegFile = fopen(pFileName, "wb")) == NULL)
    {    
        return -1;
    }    
    jpeg_stdio_dest(&cinfo, pJpegFile);

    // image width and height, in pixels
    cinfo.image_width      = nWidth;
    cinfo.image_height     = nHeight;    
    cinfo.input_components = 3;    // # of color components per pixel
    cinfo.in_color_space   = JCS_YCbCr;  //colorspace of input image
    jpeg_set_defaults(&cinfo);
    jpeg_set_quality(&cinfo, 75, TRUE);
  
    cinfo.jpeg_color_space = JCS_YCbCr;
    cinfo.comp_info[0].h_samp_factor = 2;
    cinfo.comp_info[0].v_samp_factor = 2;
  
    jpeg_start_compress(&cinfo, TRUE);
  
    if(NULL == (yuvbuf=(unsigned char *)malloc(nWidth*3)))
    {
        return -1;
    }
    memset(yuvbuf, 0, nWidth*3);
  
    ybase=pYUVBuffer;
    ubase=pYUVBuffer+nWidth*nHeight;
    while (cinfo.next_scanline < cinfo.image_height)
    {
        idx=0;
        for(i=0;i<nWidth;i++)
        {   
            yuvbuf[idx++]=ybase[i + j * nWidth];
            yuvbuf[idx++]=ubase[j/2 * nWidth+(i/2)*2];
            yuvbuf[idx++]=ubase[j/2 * nWidth+(i/2)*2+1];
        }  
        row_pointer[0] = yuvbuf;
        jpeg_write_scanlines(&cinfo, row_pointer, 1);
        j++;
    }
    jpeg_finish_compress( &cinfo);

    jpeg_destroy_compress(&cinfo);
    fclose(pJpegFile);

    return 0;    
}

/// Encodes one NV12 frame as an in-memory JPEG (quality 75, YCbCr 4:2:0).
///
/// Expected layout of `pYUVBuffer`:
/// - a Y plane of nWidth*nHeight bytes;
/// - immediately followed by an interleaved Cb,Cr plane with one pair per 2x2 pixel
///   block and a row stride of nWidth bytes.
///
/// NOTE(review): for odd nWidth/nHeight the chroma reads assume a padded chroma plane.
/// Confirm callers pad odd frames.
///
/// @param pYUVBuffer NV12 frame data (see above).
/// @param nWidth     frame width in pixels; must be > 0.
/// @param nHeight    frame height in pixels; must be > 0.
/// @param sJpeg      receives the encoded JPEG bytes on success; untouched on failure.
/// @return true on success; false for a null buffer or non-positive dimensions.
/// @note jpeg_std_error() is the error manager, so libjpeg-internal errors end the
///       process through libjpeg's default error_exit rather than returning false.
bool Nv12ToJpg(const uint8_t* pYUVBuffer, const int nWidth, const int nHeight, std::string& sJpeg)
{
    if (pYUVBuffer == nullptr || nWidth <= 0 || nHeight <= 0)
    {
        return false;
    }

    const size_t width  = static_cast<size_t>(nWidth);
    const size_t height = static_cast<size_t>(nHeight);

    // One interleaved Y,Cb,Cr scanline. It is allocated before any libjpeg state
    // exists, so a bad_alloc here leaks nothing.
    std::vector<unsigned char> scanline(width * 3);

    struct jpeg_compress_struct cinfo;
    struct jpeg_error_mgr jerr;
    cinfo.err = jpeg_std_error(&jerr);
    jpeg_create_compress(&cinfo);

    // libjpeg allocates (and grows) this buffer itself with malloc(); the caller must
    // release it with free(), never delete[].
    // NOTE(review): libjpeg-turbo declares the size parameter as `unsigned long*`
    // (IJG libjpeg 9 uses `size_t*`). The two are the same type on LP64 Linux but
    // not on Windows; confirm against the libjpeg flavour this is built with.
    size_t outSize = 0;
    unsigned char* buffer = nullptr;
    jpeg_mem_dest(&cinfo, &buffer, &outSize);

    // Input description: image size in pixels, 3 interleaved YCbCr samples per pixel.
    cinfo.image_width      = static_cast<unsigned int>(nWidth);
    cinfo.image_height     = static_cast<unsigned int>(nHeight);
    cinfo.input_components = 3;
    cinfo.in_color_space   = JCS_YCbCr;

    jpeg_set_defaults(&cinfo);
    jpeg_set_quality(&cinfo, 75, TRUE);

    // Luma at 2x2 relative to chroma (4:2:0), matching NV12's half-resolution chroma.
    cinfo.jpeg_color_space = JCS_YCbCr;
    cinfo.comp_info[0].h_samp_factor = 2;
    cinfo.comp_info[0].v_samp_factor = 2;

    jpeg_start_compress(&cinfo, TRUE);

    const uint8_t* yPlane  = pYUVBuffer;
    const uint8_t* uvPlane = pYUVBuffer + width * height;  // size_t math: no int overflow
    JSAMPROW rowPointer = scanline.data();

    while (cinfo.next_scanline < cinfo.image_height)
    {
        const size_t row = cinfo.next_scanline;
        const uint8_t* yRow  = yPlane + row * width;
        const uint8_t* uvRow = uvPlane + (row / 2) * width;  // one chroma row per two luma rows
        size_t idx = 0;
        for (size_t x = 0; x < width; ++x)
        {
            const size_t uvOffset = (x / 2) * 2;  // start of this pixel's Cb,Cr pair
            scanline[idx++] = yRow[x];
            scanline[idx++] = uvRow[uvOffset];
            scanline[idx++] = uvRow[uvOffset + 1];
        }
        jpeg_write_scanlines(&cinfo, &rowPointer, 1);
    }

    jpeg_finish_compress(&cinfo);
    jpeg_destroy_compress(&cinfo);

    // Take ownership of the malloc'd output so it is freed even if the string copy throws.
    struct FreeDeleter
    {
        void operator()(unsigned char* p) const noexcept { std::free(p); }
    };
    std::unique_ptr<unsigned char, FreeDeleter> jpegData(buffer);
    if (!jpegData)
    {
        return false;
    }

    std::string output(reinterpret_cast<const char*>(jpegData.get()), outSize);
    sJpeg.swap(output);
    return true;
}

/// Demo driver: reads a 1280x720 NV12 frame from "in.yuv" and writes it to "out.jpg".
/// @return 0 on success, -1 if reading, encoding, or writing fails.
int main()
{
    const int width = 1280;
    const int height = 720;
    // NV12 frame: width*height luma bytes plus width*height/2 interleaved chroma bytes.
    const int frameSize = width * height * 3 / 2;

    // The vector owns the frame, so it is released on every return path.
    std::vector<uint8_t> yuvImage(static_cast<size_t>(frameSize));

    if (!ReadYuvFromFile("in.yuv", yuvImage.data(), frameSize)) {
        std::cout << "Failed to read YUV file." << std::endl;
        return -1;
    }

    // Nv12ToJpgFile("out1.jpg", yuvImage.data(), width, height);
    std::string jpegOutput;
    if (!Nv12ToJpg(yuvImage.data(), width, height, jpegOutput)) {
        std::cout << "JPEG failed." << std::endl;
        return -1;
    }

    std::ofstream outFile("out.jpg", std::ios::binary);
    outFile.write(jpegOutput.data(), static_cast<std::streamsize>(jpegOutput.size()));
    // Covers both a failed open and a failed write.
    if (!outFile) {
        std::cout << "Failed to write out.jpg." << std::endl;
        return -1;
    }

    std::cout << "JPEG length: " << jpegOutput.size() << " bytes" << std::endl;
    return 0;
}