/** 
 * @file llstacktrace.cpp
 * @brief stack tracing functionality
 *
 * $LicenseInfo:firstyear=2001&license=viewerlgpl$
 * Second Life Viewer Source Code
 * Copyright (C) 2010, Linden Research, Inc.
 * 
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation;
 * version 2.1 of the License only.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 * 
 * Linden Research, Inc., 945 Battery Street, San Francisco, CA  94111  USA
 * $/LicenseInfo$
 */

#include "linden_common.h"
#include "llstacktrace.h"

#ifdef LL_WINDOWS

#include <iostream>
#include <sstream>

#include "llwin32headerslean.h"
#pragma warning (push)
#pragma warning (disable:4091) // a microsoft header has warnings. Very nice.
#include <dbghelp.h>
#pragma warning (pop)

// Function type used to call ntdll's "RtlCaptureStackBackTrace" export
// through a pointer obtained with GetProcAddress (see the static pointer
// defined right after this typedef).
// As used in this file: the return value is treated as the number of frame
// addresses written to 'backtrace', and 'backtrace_hash' is always passed as
// NULL.
typedef USHORT NTAPI RtlCaptureStackBackTrace_Function(
    IN ULONG frames_to_skip,
    IN ULONG frames_to_capture,
    OUT PVOID *backtrace,
    OUT PULONG backtrace_hash);

// Entry point of ntdll's RtlCaptureStackBackTrace, looked up by name when
// this file's statics are initialized.  The value is whatever GetProcAddress
// returned, so it is NULL if the lookup failed.
static RtlCaptureStackBackTrace_Function* const RtlCaptureStackBackTrace_fn =
	reinterpret_cast<RtlCaptureStackBackTrace_Function*>(
		GetProcAddress(GetModuleHandleA("ntdll.dll"), "RtlCaptureStackBackTrace"));

bool ll_get_stack_trace(std::vector<std::string>& lines)
{
	const S32 MAX_STACK_DEPTH = 32;
	const S32 STRING_NAME_LENGTH = 200;
	const S32 FRAME_SKIP = 2;
	static BOOL symbolsLoaded = false;
	static BOOL firstCall = true;

	HANDLE hProc = GetCurrentProcess();

	// load the symbols if they're not loaded
	if(!symbolsLoaded && firstCall)
	{
		symbolsLoaded = SymInitialize(hProc, NULL, true);
		firstCall = false;
	}

	// if loaded, get the call stack
	if(symbolsLoaded)
	{
		// create the frames to hold the addresses
		void* frames[MAX_STACK_DEPTH];
		memset(frames, 0, sizeof(void*)*MAX_STACK_DEPTH);
		S32 depth = 0;

		// get the addresses
		depth = RtlCaptureStackBackTrace_fn(FRAME_SKIP, MAX_STACK_DEPTH, frames, NULL);

		IMAGEHLP_LINE64 line;
		memset(&line, 0, sizeof(IMAGEHLP_LINE64));
		line.SizeOfStruct = sizeof(IMAGEHLP_LINE64);

		// create something to hold address info
		PIMAGEHLP_SYMBOL64 pSym;
		pSym = (PIMAGEHLP_SYMBOL64)malloc(sizeof(IMAGEHLP_SYMBOL64) + STRING_NAME_LENGTH);
		memset(pSym, 0, sizeof(IMAGEHLP_SYMBOL64) + STRING_NAME_LENGTH);
		pSym->MaxNameLength = STRING_NAME_LENGTH;
		pSym->SizeOfStruct = sizeof(IMAGEHLP_SYMBOL64);

		// get address info for each address frame
		// and store
		for(S32 i=0; i < depth; i++)
		{
			std::stringstream stack_line;
			BOOL ret;

			DWORD64 addr = (DWORD64)frames[i];
			ret = SymGetSymFromAddr64(hProc, addr, 0, pSym);
			if(ret)
			{
				stack_line << pSym->Name << " ";
			}

			DWORD dummy;
			ret = SymGetLineFromAddr64(hProc, addr, &dummy, &line);
			if(ret)
			{
				std::string file_name = line.FileName;
				std::string::size_type index = file_name.rfind("\\");
				stack_line << file_name.substr(index + 1, file_name.size()) << ":" << line.LineNumber; 
			}

			lines.push_back(stack_line.str());
		}
		
		free(pSym);

		// TODO: figure out a way to cleanup symbol loading
		// Not hugely necessary, however.
		//SymCleanup(hProc);
		return true;
	}
	else
	{
		lines.push_back("Stack Trace Failed.  PDB symbol info not loaded");
	}

	return false;
}

void ll_get_stack_trace_internal(std::vector<std::string>& lines)
{
	const S32 MAX_STACK_DEPTH = 100;
	const S32 STRING_NAME_LENGTH = 256;

	HANDLE process = GetCurrentProcess();
	SymInitialize( process, NULL, TRUE );

	void *stack[MAX_STACK_DEPTH];

	unsigned short frames = RtlCaptureStackBackTrace_fn( 0, MAX_STACK_DEPTH, stack, NULL );
	SYMBOL_INFO *symbol = (SYMBOL_INFO*)calloc(sizeof(SYMBOL_INFO) + STRING_NAME_LENGTH * sizeof(char), 1);
	symbol->MaxNameLen = STRING_NAME_LENGTH-1;
	symbol->SizeOfStruct = sizeof(SYMBOL_INFO);

	for(unsigned int i = 0; i < frames; i++) 
	{
		SymFromAddr(process, (DWORD64)(stack[i]), 0, symbol);
		lines.push_back(symbol->Name);
	}

	free( symbol );
}

#else

/**
 * Fallback compiled when LL_WINDOWS is not defined: no stack capture is
 * performed.
 *
 * @param lines  left untouched.
 * @return always false.
 */
bool ll_get_stack_trace(std::vector<std::string>& lines)
{
	// Nothing is appended to 'lines'; just report that no trace was produced.
	const bool trace_captured = false;
	return trace_captured;
}

/**
 * Fallback compiled when LL_WINDOWS is not defined: does nothing.
 *
 * @param lines  left untouched.
 */
void ll_get_stack_trace_internal(std::vector<std::string>& lines)
{
	// Intentionally empty: no frames are captured on this build.
	return;
}

#endif