
/*
 *  ARP response injector
 *  Copyright (c) 2006, Hector Martin <hector@marcansoft.com>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 */


#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <pcap.h>
#include <errno.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <asm/types.h>
#include <unistd.h>
#include <signal.h>
#include <linux/if_packet.h>
#include <linux/if_ether.h>
#include <linux/if_arp.h>


// MAC stuff borrowed from hostapd/wpa_supplicant
/*
 * Convert one ASCII hexadecimal digit to its numeric value.
 *
 * Accepts '0'-'9', 'a'-'f' and 'A'-'F'.
 * Returns the value 0..15, or -1 if c is not a hexadecimal digit.
 */
static int hex2num(char c)
{
	int value = -1;

	if (c >= '0' && c <= '9') {
		value = c - '0';
	} else if (c >= 'a' && c <= 'f') {
		value = 10 + (c - 'a');
	} else if (c >= 'A' && c <= 'F') {
		value = 10 + (c - 'A');
	}

	return value;
}

/*
 * Convert the two hexadecimal digits at hex[0] and hex[1] to a byte value.
 *
 * Returns 0..255 on success, or -1 if either character is not a
 * hexadecimal digit. The second character is only read when the first
 * one is valid.
 */
static int hex2byte(const char *hex)
{
	int high;
	int low;

	high = hex2num(hex[0]);
	if (high < 0) {
		return -1;
	}

	low = hex2num(hex[1]);
	if (low < 0) {
		return -1;
	}

	return (high << 4) | low;
}

/*
 * Convert a 4-bit value to its upper-case hexadecimal digit.
 *
 * Values below 10 map onto '0'..'9', values from 10 upwards onto 'A'...
 * Callers are expected to pass 0..15; other values are not range-checked.
 */
static char nibble2hex(int n) {
	return (n < 10) ? (n + '0') : (n - 10 + 'A');
}

/*
 * Parse a textual MAC address of the form "XX:XX:XX:XX:XX:XX" into 6 bytes.
 *
 * txt:  string holding six pairs of hexadecimal digits separated by ':'.
 * addr: output buffer of at least 6 bytes.
 *
 * Returns 0 on success and -1 on the first malformed character. On failure
 * the octets parsed before the error have already been stored in addr.
 * Characters following the sixth octet are not examined.
 */
int hwaddr_aton(const char *txt, unsigned char *addr)
{
	const char *cursor = txt;
	int octet = 0;

	while (octet < 6) {
		int high;
		int low;

		high = hex2num(cursor[0]);
		if (high < 0) {
			return -1;
		}
		low = hex2num(cursor[1]);
		if (low < 0) {
			return -1;
		}
		cursor += 2;

		addr[octet] = (high << 4) | low;
		octet++;

		/* Every octet except the last must be followed by a colon. */
		if (octet < 6) {
			if (*cursor != ':') {
				return -1;
			}
			cursor++;
		}
	}

	return 0;
}

/*
 * Format a 6-byte MAC address as upper-case "XX:XX:XX:XX:XX:XX".
 *
 * The result lives in a static buffer that is overwritten by every call,
 * so the string must be consumed (or copied) before calling again.
 */
char *hwaddr_ntoa(unsigned char *mac) {
	/* 6 octets * 2 digits + 5 separators + terminating NUL */
	static char text[18];
	int out = 0;
	int octet;

	for (octet = 0; octet < 6; octet++) {
		if (octet > 0) {
			text[out++] = ':';
		}
		text[out++] = nibble2hex(mac[octet] >> 4);
		text[out++] = nibble2hex(mac[octet] & 0x0F);
	}
	text[out] = '\0';

	return text;
}

/*
 * Printable names for ARP opcodes, indexed by opcode value (0..10).
 * Opcodes 5-7 have no name in this table and map to an empty string.
 */
char *arp_types[] = {
	[0]  = "Null",
	[1]  = "ARP Request",
	[2]  = "ARP Reply",
	[3]  = "RARP Request",
	[4]  = "RARP Reply",
	[5]  = "",
	[6]  = "",
	[7]  = "",
	[8]  = "InARP Request",
	[9]  = "InARP Reply",
	[10] = "(ATM)ARP NAK",
};

/*
 * Address section of an ARP packet for the Ethernet/IPv4 combination.
 * main_loop() overlays this structure on the bytes that directly follow
 * struct arphdr, after checking that the hardware type is ARPHRD_ETHER and
 * the protocol type is ETH_P_IP.
 *
 * The structure is packed so that no padding is inserted between the
 * 6-byte MAC fields and the 4-byte IP fields.
 */
struct arp_eth_payload {
	unsigned char s_mac[6];	/* MAC address of the packet's sender ("From") */
	struct in_addr s_ip;	/* IP address of the packet's sender */
	unsigned char t_mac[6];	/* MAC address of the target ("To") */
	struct in_addr t_ip;	/* IP address of the target; compared against the DS IP */
} __attribute__((packed));

void main_loop(const char *iface, unsigned char *dsmac, struct in_addr *dsip, unsigned char *srcmac) {
	char errbuf[PCAP_ERRBUF_SIZE];
	pcap_t* descr;
	const unsigned char *packet;
	struct pcap_pkthdr hdr;     /* pcap.h */
	struct ethhdr *eptr;  /* net/ethernet.h */
	int i;

	/*
		Open up interface for sniffing 
		Note: we shouldn't need promiscuous mode here.
	*/

	descr = pcap_open_live(iface,BUFSIZ,0,-1,errbuf);

	if(descr == NULL)
	{
		fprintf(stderr,"pcap_open_live(): %s\n",errbuf);
		exit(1);
	}

	printf("\nListening on %s...\n",iface);

	/* Grab packets */
	while(1) {
		packet = pcap_next(descr,&hdr);
		if(packet == NULL)
		{
			/* screwy packets */
			fprintf(stderr, "Could not grab packet\n");
			exit(1);
		}
		//printf("Grabbed packet of length %d\n",hdr.len);
		// We're assuming it's Ethernet here.
		eptr = (struct ethhdr *) packet;
		if(ntohs(eptr->h_proto) == ETH_P_ARP) {
			struct arphdr *ahdr = (struct arphdr *)(eptr+1);
			printf("\nARP packet received (len: %d)\n",hdr.len);
			printf(" Source MAC: %s\n",hwaddr_ntoa(eptr->h_dest));
			printf(" Dest MAC:   %s\n",hwaddr_ntoa(eptr->h_source));
			if((ntohs(ahdr->ar_hrd)!=ARPHRD_ETHER) || (ntohs(ahdr->ar_pro)!=ETH_P_IP)) {
				printf(" Unknown hardware and protocol address formats: %d %d",ntohs(ahdr->ar_hrd),ntohs(ahdr->ar_pro));
			} else {
				struct arp_eth_payload *arpeth = (struct arp_eth_payload *)(ahdr+1);
				printf(" Address types: Ethernet -> IP\n");
				if(ntohs(ahdr->ar_op)<=10) {
					printf("  Opcode: %d %s\n",ntohs(ahdr->ar_op),arp_types[ntohs(ahdr->ar_op)]);
				} else {
					printf("  Opcode: %d\n",ntohs(ahdr->ar_op));
				}
				printf("  From: %s - %s\n",hwaddr_ntoa(arpeth->s_mac),inet_ntoa(arpeth->s_ip));
				printf("  To:   %s - %s\n",hwaddr_ntoa(arpeth->t_mac),inet_ntoa(arpeth->t_ip));

				if((ntohs(ahdr->ar_op)==ARPOP_REQUEST) && memcmp(&arpeth->t_ip,dsip,sizeof(struct in_addr))==0) {
					printf("   REQUEST MATCH! Injecting ARP reply...\n");

					int packet_size=sizeof(struct ethhdr)+sizeof(struct arphdr)+sizeof(struct arp_eth_payload);
					char *pkt_injected = malloc(packet_size);
					struct ethhdr *inject_ether = (struct ethhdr *)pkt_injected;
					struct arphdr *inject_arp = (struct arphdr *)(inject_ether+1);
					struct arp_eth_payload *inject_arpeth = (struct arp_eth_payload *)(inject_arp+1);
					struct sockaddr_ll socket_address;
					struct ifreq ifr;

					// Build packet
					memset(pkt_injected,0,packet_size);

					memcpy(inject_ether->h_source,srcmac,6);
					memcpy(inject_ether->h_dest,arpeth->s_mac,6);
					inject_ether->h_proto = htons(ETH_P_ARP);
					inject_arp->ar_hrd = htons(ARPHRD_ETHER);
					inject_arp->ar_pro = htons(ETH_P_IP);
					inject_arp->ar_hln = 6;
					inject_arp->ar_pln = 4;
					inject_arp->ar_op = htons(ARPOP_REPLY);
					memcpy(inject_arpeth->s_mac,dsmac,6);
					memcpy(&inject_arpeth->s_ip,dsip,4);
					memcpy(inject_arpeth->t_mac,arpeth->s_mac,6);
					memcpy(&inject_arpeth->t_ip,&arpeth->s_ip,4);
					
					socket_address.sll_family = PF_PACKET;	
					socket_address.sll_protocol = htons(ETH_P_IP); //ignored
					socket_address.sll_hatype = ARPHRD_ETHER;
					socket_address.sll_pkttype  = PACKET_HOST;
					socket_address.sll_halen    = ETH_ALEN;		
					memcpy(socket_address.sll_addr,arpeth->s_mac,6);
					socket_address.sll_addr[6] = 0x00;/*not used*/
					socket_address.sll_addr[7] = 0x00;/*not used*/



					int packet_socket = socket(PF_PACKET, SOCK_RAW, htons(ETH_P_ALL)); //I don't care about receiving packets through here.
					if (packet_socket == -1) {
						fprintf(stderr, "Could not create socket\n");
						exit(1);
					}
					/*retrieve ethernet interface index*/
					strcpy(ifr.ifr_name, iface);
					if (ioctl(packet_socket, SIOCGIFINDEX, &ifr) == -1) {
						fprintf(stderr, "Could not get interface index\n");
						exit(1);
					}
					socket_address.sll_ifindex = ifr.ifr_ifindex;

					int result = sendto(packet_socket,pkt_injected,packet_size,0,(struct sockaddr*)&socket_address,sizeof(socket_address));
					if(result == -1) {
						fprintf(stderr, "Could not send packet\n");
						perror("sendto");
						exit(1);
					}
					close(packet_socket);

				}
			}
		}
	}
}

/*
 * Program entry point.
 *
 * Usage: dsarp iface DS_MAC DS_IP [Sender_MAC]
 *
 * Parses and validates the command line, echoes the resulting settings and
 * hands control to main_loop(). When Sender_MAC is omitted, DS_MAC is used
 * as the sender MAC as well. Returns 1 on a usage or parse error.
 */
int main(int argc, char **argv) {
	struct in_addr ds_ip;
	unsigned char ds_mac[6];
	unsigned char sender_mac[6];
	char *ifname;
	char *ds_mac_txt;
	char *ds_ip_txt;
	char *sender_mac_txt;

	if(argc < 4 || argc > 5) {
		fprintf(stderr,"Usage: dsarp iface DS_MAC DS_IP [Sender_MAC]\n");
		return 1;
	}

	ifname = argv[1];
	ds_mac_txt = argv[2];
	ds_ip_txt = argv[3];
	/* The sender MAC is optional and falls back to the DS MAC. */
	sender_mac_txt = (argc == 5) ? argv[4] : ds_mac_txt;

	if(!inet_aton(ds_ip_txt,&ds_ip)) {
		fprintf(stderr,"Invalid IP address '%s'\n",ds_ip_txt);
		return 1;
	}
	if(hwaddr_aton(ds_mac_txt,ds_mac)) {
		fprintf(stderr,"Invalid MAC address '%s'\n",ds_mac_txt);
		return 1;
	}
	if(hwaddr_aton(sender_mac_txt,sender_mac)) {
		fprintf(stderr,"Invalid MAC address '%s'\n",sender_mac_txt);
		return 1;
	}

	printf("Interface: %s\n",ifname);
	printf("DS IP: %s\n",inet_ntoa(ds_ip));
	printf("DS MAC: %s\n",hwaddr_ntoa(ds_mac));
	printf("Sender MAC: %s\n",hwaddr_ntoa(sender_mac));

	main_loop(ifname,ds_mac,&ds_ip,sender_mac);

	return 0;
}

